Browse Source

使用抓握训练范式数据

dk 1 year ago
parent
commit
91b0d6c486
2 changed files with 4 additions and 4 deletions
  1. 1 1
      backend/training.py
  2. 3 3
      backend/validation.py

+ 1 - 1
backend/training.py

@@ -122,7 +122,7 @@ if __name__ == '__main__':
         event_id[f] = neo.FINGERMODEL_IDS[f]
     
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=1., ori_epoch_length=5, mov_trial_ind=[6], rest_trial_ind=[7])
+    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=1., ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
     # train model
     model = train_model(raw, event_id=event_id, model_type=model_type)

+ 3 - 3
backend/validation.py

@@ -105,7 +105,7 @@ if __name__ == '__main__':
     # TODO: load subject config
 
     data_dir = f'./data/{subj_name}/'
-    model_path = f'./static/models/{subj_name}/baseline_rest+cylinder_11-19-2023-17-31-18.pkl'
+    model_path = f'./static/models/{subj_name}/baseline_rest+flex_11-20-2023-19-26-37.pkl'
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
@@ -114,13 +114,13 @@ if __name__ == '__main__':
         event_id[f] = neo.FINGERMODEL_IDS[f]
     
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[6], rest_trial_ind=[7])
+    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
     # do validations
     metrics, fig_erds, fig_pred = validation(raw, 
                                              event_id, 
                                              model=model_path, 
-                                             state_change_threshold=0.8)
+                                             state_change_threshold=0.75)
     fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   
     logging.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')