Browse Source

增加model config

dk 1 year ago
parent
commit
bb87fec41e
2 changed files with 10 additions and 12 deletions
  1. 10 11
      backend/training.py
  2. 0 1
      backend/validation.py

+ 10 - 11
backend/training.py

@@ -41,26 +41,24 @@ def parse_args():
     return parser.parse_args()
 
 
-def train_model(raw, event_id, trial_duration=1., model_type='baseline'):
+def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
     """
     """
     events, _ = mne.events_from_annotations(raw, event_id=event_id)
     if model_type.lower() == 'baseline':
-        model = _train_baseline_model(raw, events, duration=trial_duration)
+        model = _train_baseline_model(raw, events, duration=trial_duration, **model_kwargs)
     elif model_type.lower() == 'riemann':
-        # TODO: load subject config
-        model = _train_riemann_model(raw, events, duration=trial_duration)
+        model = _train_riemann_model(raw, events, duration=trial_duration, **model_kwargs)
     else:
         raise NotImplementedError
     return model
 
 
-def _train_riemann_model(raw, events, duration=1., lfb_bands=[(15, 35), (35, 55)], hg_bands=[(55, 95), (105, 145)]):
+def _train_riemann_model(raw, events, duration=1., lf_bands=[(15, 35), (35, 50)], hg_bands=[(55, 95), (105, 145)]):
     fs = raw.info['sfreq']
     n_ch = len(raw.ch_names)
-    feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
+    feat_extractor = feature_extractors.FeatExtractor(fs, lf_bands, hg_bands)
     filtered_data = feat_extractor.transform(raw.get_data())
-    # TODO: find proper latency
     X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
     y = events[:, -1]
     
@@ -68,7 +66,7 @@ def _train_riemann_model(raw, events, duration=1., lfb_bands=[(15, 35), (35, 55)
     X = scaler.fit_transform(X)
 
     # compute covariance
-    lfb_dim = len(lfb_bands) * n_ch
+    lfb_dim = len(lf_bands) * n_ch
     hgs_dim = len(hg_bands) * n_ch
     cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
     X_cov = cov_model.fit_transform(X)
@@ -130,12 +128,13 @@ if __name__ == '__main__':
     args = parse_args()
     subj_name = args.subj
     model_type = args.model_type
-    # TODO: load subject config
-    # include frequency band, model_type, upsampled_trial_duration
 
     data_dir = f'./data/{subj_name}/'
     model_dir = './static/models/'
 
+    with open(os.path.join(data_dir, 'config.yml'), 'r') as f:
+        model_config = yaml.safe_load(f)[model_type]
+
     with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
@@ -148,7 +147,7 @@ if __name__ == '__main__':
     raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
     # train model
-    model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration)
+    model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)
     
     # save
     model_saver(model, model_dir, model_type, subj_name, event_id)

+ 0 - 1
backend/validation.py

@@ -150,7 +150,6 @@ def _event_to_stim_channel(events, time_length):
 if __name__ == '__main__':
     args = parse_args()
     subj_name = args.subj
-    # TODO: load subject config
 
     data_dir = f'./data/{subj_name}/'