|
@@ -41,26 +41,24 @@ def parse_args():
|
|
|
return parser.parse_args()
|
|
|
|
|
|
|
|
|
-def train_model(raw, event_id, trial_duration=1., model_type='baseline'):
|
|
|
+def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
|
|
|
"""
|
|
|
"""
|
|
|
events, _ = mne.events_from_annotations(raw, event_id=event_id)
|
|
|
if model_type.lower() == 'baseline':
|
|
|
- model = _train_baseline_model(raw, events, duration=trial_duration)
|
|
|
+ model = _train_baseline_model(raw, events, duration=trial_duration, **model_kwargs)
|
|
|
elif model_type.lower() == 'riemann':
|
|
|
- # TODO: load subject config
|
|
|
- model = _train_riemann_model(raw, events, duration=trial_duration)
|
|
|
+ model = _train_riemann_model(raw, events, duration=trial_duration, **model_kwargs)
|
|
|
else:
|
|
|
raise NotImplementedError
|
|
|
return model
|
|
|
|
|
|
|
|
|
-def _train_riemann_model(raw, events, duration=1., lfb_bands=[(15, 35), (35, 55)], hg_bands=[(55, 95), (105, 145)]):
|
|
|
+def _train_riemann_model(raw, events, duration=1., lf_bands=[(15, 35), (35, 50)], hg_bands=[(55, 95), (105, 145)]):
|
|
|
fs = raw.info['sfreq']
|
|
|
n_ch = len(raw.ch_names)
|
|
|
- feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
|
|
|
+ feat_extractor = feature_extractors.FeatExtractor(fs, lf_bands, hg_bands)
|
|
|
filtered_data = feat_extractor.transform(raw.get_data())
|
|
|
- # TODO: find proper latency
|
|
|
X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
|
|
|
y = events[:, -1]
|
|
|
|
|
@@ -68,7 +66,7 @@ def _train_riemann_model(raw, events, duration=1., lfb_bands=[(15, 35), (35, 55)
|
|
|
X = scaler.fit_transform(X)
|
|
|
|
|
|
# compute covariance
|
|
|
- lfb_dim = len(lfb_bands) * n_ch
|
|
|
+ lfb_dim = len(lf_bands) * n_ch
|
|
|
hgs_dim = len(hg_bands) * n_ch
|
|
|
cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
|
|
|
X_cov = cov_model.fit_transform(X)
|
|
@@ -130,12 +128,13 @@ if __name__ == '__main__':
|
|
|
args = parse_args()
|
|
|
subj_name = args.subj
|
|
|
model_type = args.model_type
|
|
|
- # TODO: load subject config
|
|
|
- # include frequency band, model_type, upsampled_trial_duration
|
|
|
|
|
|
data_dir = f'./data/{subj_name}/'
|
|
|
model_dir = './static/models/'
|
|
|
|
|
|
+ with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
|
|
|
+ model_config = yaml.safe_load(f)[model_type]
|
|
|
+
|
|
|
with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
|
|
|
info = yaml.safe_load(f)
|
|
|
sessions = info['sessions']
|
|
@@ -148,7 +147,7 @@ if __name__ == '__main__':
|
|
|
raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
|
|
|
|
|
|
# train model
|
|
|
- model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration)
|
|
|
+ model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)
|
|
|
|
|
|
# save
|
|
|
model_saver(model, model_dir, model_type, subj_name, event_id)
|