Browse Source

Fix: training error

dk 1 year ago
parent
commit
a7f586f8d2
1 changed files with 2 additions and 2 deletions
  1. 2 2
      backend/training.py

+ 2 - 2
backend/training.py

@@ -60,7 +60,7 @@ def _train_riemann_model(raw, events, lfb_bands=[(15, 30), [30, 45]], hg_bands=[
 
     # train and dump best model
     model_to_train = model_func(**best_param)
-    model_to_train.fit(X, y)
+    model_to_train.fit(X_cov, y)
     return [feat_extractor, scaler, cov_model, model_to_train]
 
 
@@ -108,7 +108,7 @@ def model_saver(model, model_path, model_type, subject_id, event_id):
 if __name__ == '__main__':
     # TODO: argparse
     subj_name = 'ylj'
-    model_type = 'baseline'
+    model_type = 'riemann'
     # TODO: load subject config
 
     data_dir = f'./data/{subj_name}/train/'