Forráskód Böngészése

在线模拟引入hmm训练得到的transmatrix

dk 1 éve
szülő
commit
7252293bd3
2 módosított fájl, 19 hozzáadás és 13 törlés
  1. 14 12
      backend/online_sim.py
  2. 5 1
      backend/tests/test_validation.py

+ 14 - 12
backend/online_sim.py

@@ -125,8 +125,6 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
 
 
 def simulation(raw_val, event_id, model, 
-               state_trans_prob=0.8,
-               state_change_threshold=0.8, 
                epoch_length=1., 
                step_length=0.1,
                event_trial_length=5.):
@@ -135,7 +133,6 @@ def simulation(raw_val, event_id, model,
         raw (mne.io.Raw)
         event_id (dict)
         model: validate existing model, 
-        state_change_threshold (float): default 0.8
         epoch_length (float): batch data length, default 1 (s)
         step_length (float): data step length, default 0.1 (s)
         event_trial_length (float): 
@@ -153,15 +150,11 @@ def simulation(raw_val, event_id, model,
                                         rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'], 
                                         mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'], 
                                         use_original_label=True)
-    
-    model_hmm = online.model_loader(model, 
-                                    state_trans_prob=state_trans_prob,
-                                    state_change_threshold=state_change_threshold)
 
     # run with and without hmm
     fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, 
                                                           events_val, 
-                                                          model_hmm, 
+                                                          model, 
                                                           epoch_length, 
                                                           step_length,
                                                           event_trial_length=event_trial_length)
@@ -194,7 +187,10 @@ if __name__ == '__main__':
 
     data_dir = f'./data/{subj_name}/'
     
-    model_path = f'./static/models/{subj_name}/{args.model_filename}'
+    model_filename = args.model_filename.split('.')[0]
+    model_path = f'./static/models/{subj_name}/{model_filename}.pkl'
+    transmat_path = f'./static/models/{subj_name}/{model_filename}_transmat.txt'
+
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
@@ -210,13 +206,19 @@ if __name__ == '__main__':
                                 mov_trial_ind=[2], 
                                 rest_trial_ind=[1], 
                                 upsampled_epoch_length=None)
+    
+    # load model
+    input_kwargs = {
+        'transmat': transmat_path if os.path.isfile(transmat_path) else None,
+        'state_trans_prob': args.state_trans_prob,
+        'state_change_threshold': args.state_change_threshold
+    }
+    model_hmm = online.model_loader(model_path, **input_kwargs)
 
     # do validations
     metric_hmm, metric_naive, fig_pred = simulation(raw, 
                                              event_id, 
-                                             model=model_path, 
-                                             state_trans_prob=args.state_trans_prob,
-                                             state_change_threshold=args.state_change_threshold,
+                                             model=model_hmm, 
                                              epoch_length=config_info['buffer_length'],
                                              step_length=config_info['buffer_length'],
                                              event_trial_length=trial_time)

+ 5 - 1
backend/tests/test_validation.py

@@ -5,6 +5,7 @@ from glob import glob
 import shutil
 
 from bci_core import utils as ana_utils
+from bci_core.online import model_loader
 from training import train_model, model_saver
 from dataloaders import library_ieeg
 from online_sim import simulation
@@ -51,7 +52,10 @@ class TestOnlineSim(unittest.TestCase):
         self.assertEqual(recall, 1)
 
     def test_sim(self):
-        metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7, epoch_length=1., step_length=0.1, state_trans_prob=0.7)
+        model = model_loader(self.model_path, 
+                             state_change_threshold=0.7,
+                             state_trans_prob=0.7)
+        metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=model, epoch_length=1., step_length=0.1)
         fig_pred.savefig('./tests/data/pred.pdf')   
 
         print(metric_hmm)