|
@@ -60,7 +60,7 @@ def _train_riemann_model(raw, events, lfb_bands=[(15, 30), [30, 45]], hg_bands=[
|
|
|
|
|
|
# train and dump best model
|
|
|
model_to_train = model_func(**best_param)
|
|
|
- model_to_train.fit(X, y)
|
|
|
+ model_to_train.fit(X_cov, y)
|
|
|
return [feat_extractor, scaler, cov_model, model_to_train]
|
|
|
|
|
|
|
|
@@ -108,7 +108,7 @@ def model_saver(model, model_path, model_type, subject_id, event_id):
|
|
|
if __name__ == '__main__':
|
|
|
# TODO: argparse
|
|
|
subj_name = 'ylj'
|
|
|
- model_type = 'baseline'
|
|
|
+ model_type = 'riemann'
|
|
|
# TODO: load subject config
|
|
|
|
|
|
data_dir = f'./data/{subj_name}/train/'
|