Browse Source

Feat: argparse

dk 1 year ago
parent
commit
0bca28eb54
3 changed files with 83 additions and 8 deletions
  1. 23 0
      .vscode/launch.json
  2. 24 3
      backend/training.py
  3. 36 5
      backend/validation.py

+ 23 - 0
.vscode/launch.json

@@ -34,6 +34,29 @@
             "--model-path", "./static/models/XW01/baseline_rest+flex_11-21-2023-09-59-49.pkl"]
         },
         {
+            "name": "Train model",
+            "type": "python",
+            "request": "launch",
+            "program": "training.py",
+            "console": "integratedTerminal",
+            "cwd": "${workspaceFolder}/backend",
+            "justMyCode": true,
+            "args": ["--subj", "XW01", 
+            "--model-type", "riemann"]
+        },
+        {
+            "name": "Validate model",
+            "type": "python",
+            "request": "launch",
+            "program": "validation.py",
+            "console": "integratedTerminal",
+            "cwd": "${workspaceFolder}/backend",
+            "justMyCode": true,
+            "args": ["--subj", "XW01", 
+            "-scth", "0.75",
+            "--model-filename", "riemann_rest+flex_11-23-2023-10-56-58.pkl"]
+        },
+        {
             "name": "Python: 当前文件",
             "type": "python",
             "request": "launch",

+ 24 - 3
backend/training.py

@@ -3,6 +3,7 @@ import joblib
 import os
 from datetime import datetime
 import yaml
+import argparse
 
 import mne
 import numpy as np
@@ -20,6 +21,26 @@ logging.basicConfig(level=logging.INFO)
 config_info = settings.CONFIG_INFO
 
 
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Model validation'
+    )
+    parser.add_argument(
+        '--subj',
+        dest='subj',
+        help='Subject name',
+        default=None,
+        type=str
+    )
+    parser.add_argument(
+        '--model-type',
+        dest='model_type',
+        default='baseline',
+        type=str
+    )
+    return parser.parse_args()
+
+
 def train_model(raw, event_id, trial_duration=1., model_type='baseline'):
     """
     """
@@ -106,9 +127,9 @@ def model_saver(model, model_path, model_type, subject_id, event_id):
 
 
 if __name__ == '__main__':
-    # TODO: argparse
-    subj_name = 'XW01'
-    model_type = 'riemann'
+    args = parse_args()
+    subj_name = args.subj
+    model_type = args.model_type
     # TODO: load subject config
     # include frequency band, model_type, upsampled_trial_duration
 

+ 36 - 5
backend/validation.py

@@ -7,6 +7,7 @@ import matplotlib.pyplot as plt
 import mne
 import yaml
 import os
+import argparse
 import logging
 from scipy import stats
 from dataloaders import neo
@@ -20,6 +21,35 @@ logging.basicConfig(level=logging.DEBUG)
 config_info = settings.CONFIG_INFO
 
 
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Model validation'
+    )
+    parser.add_argument(
+        '--subj',
+        dest='subj',
+        help='Subject name',
+        default=None,
+        type=str
+    )
+    parser.add_argument(
+        '--state-change-threshold',
+        '-scth',
+        dest='state_change_threshold',
+        help='Threshold for HMM state change',
+        default=0.75,
+        type=float
+    )
+    parser.add_argument(
+        '--model-filename',
+        dest='model_filename',
+        help='Model filename',
+        default=None,
+        type=str
+    )
+    return parser.parse_args()
+
+
 class DataGenerator:
     def __init__(self, fs, X, epoch_step=1.):
         self.fs = int(fs)
@@ -89,7 +119,7 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length
 
     corr, _ = stats.pearsonr(stim_pred, stim_true)
 
-    fig_pred, ax = plt.subplots(1, 1, sharex=True, sharey=False)
+    fig_pred, ax = plt.subplots(3, 1, sharex=True, sharey=False)
     ax[0].set_title('pred')
     ax[0].plot(raw_val.times, stim_pred)
     ax[1].set_title('true')
@@ -118,12 +148,13 @@ def _event_to_stim_channel(events, time_length):
 
 
 if __name__ == '__main__':
-    # TODO: argparse
-    subj_name = 'XW01'
+    args = parse_args()
+    subj_name = args.subj
     # TODO: load subject config
 
     data_dir = f'./data/{subj_name}/'
-    model_path = f'./static/models/{subj_name}/riemann_rest+flex_11-21-2023-21-23-15.pkl'
+    
+    model_path = f'./static/models/{subj_name}/{args.model_filename}'
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
@@ -138,7 +169,7 @@ if __name__ == '__main__':
     metrics, fig_erds, fig_pred = validation(raw, 
                                              event_id, 
                                              model=model_path, 
-                                             state_change_threshold=0.75,
+                                             state_change_threshold=args.state_change_threshold,
                                              step_length=config_info['buffer_length'])
     fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))