ソースを参照

Validation若干修复

dk 1 年間 前
コミット
5373b3fa76

+ 6 - 3
backend/bci_core/online.py

@@ -1,6 +1,7 @@
 import joblib
 import numpy as np
 import random
+import logging
 from scipy import signal
 from .feature_extractors import filterbank_extractor
 from .utils import parse_model_type
@@ -28,9 +29,6 @@ class Controller:
             else:
                 raise NotImplementedError
         self.virtual_feedback_rate = virtual_feedback_rate
-    
-    def set_real_feedback_model(self, model):
-        self.real_feedback_model = model
 
     def step_decision(self, data, true_label=None):
         """抓握训练调用接口,只进行单次判决,不涉及马尔可夫过程,
@@ -147,6 +145,7 @@ class HMMModel:
         """
         fs, data = self.parse_data(data)
         p = self.step_probability(fs, data)
+        logging.debug(p, self.probability)
         return self.update_state(p)
     
     def update_state(self, current_p):
@@ -164,6 +163,10 @@ class HMMModel:
                 return current_state
             else:
                 return -1
+    
+    @property
+    def probability(self):
+        return self._probability[self._last_state]
 
 
 class BaselineHMM(HMMModel):

+ 4 - 0
backend/dataloaders/neo.py

@@ -30,6 +30,8 @@ def raw_preprocessing(data_root, session_paths:dict,
         upsampled_epoch_length: 
         ori_epoch_length (int or 'varied'): original epoch length in second
         unify_label (True, original data didn't use unified event label, do unification (old pony format); False use original)
+        mov_trial_ind: only used when unify_label == True, suggesting the raw file's annotations didn't use unified labels (old pony format)
+        rest_trial_ind: only used when unify_label == True, 
     """
     raws_loaded = load_sessions(data_root, session_paths)
     # process event
@@ -62,6 +64,8 @@ def raw_preprocessing(data_root, session_paths:dict,
     raws = mne.concatenate_raws(raws)
     raws.load_data()
 
+    # high pass
+    raws = raws.filter(1, None)
     # filter 50Hz
     raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
 

+ 12 - 3
backend/tests/test_validation.py

@@ -1,9 +1,11 @@
 import unittest
 import os
 import numpy as np
+from glob import glob
+import shutil
 
 from bci_core import utils as ana_utils
-from training import train_model
+from training import train_model, model_saver
 from dataloaders import library_ieeg
 from validation import validation
 
@@ -29,7 +31,14 @@ class TestValidation(unittest.TestCase):
             cls.raw_val.annotations.onset -= t_mid
         
         # train with the first half
-        cls.model = train_model(raw_train, event_id=cls.event_id, model_type='baseline')
+        model = train_model(raw_train, event_id=cls.event_id, model_type='baseline')
+        model_saver(model, './tests/data/', 'baseline', 'test', cls.event_id)
+        cls.model_path = glob(os.path.join('./tests/data/', 'test', '*.pkl'))[0]
+    
+    @classmethod
+    def tearDownClass(cls) -> None:
+        shutil.rmtree(os.path.join('./tests/data/', 'test'))
+        return super().tearDownClass()
     
     def test_event_metric(self):
         event_gt = np.array([[0, 0, 0], [5, 0, 1], [7, 0, 0], [9, 0, 2]])
@@ -41,7 +50,7 @@ class TestValidation(unittest.TestCase):
         self.assertEqual(recall, 1)
 
     def test_validation(self):
-        (precision, recall, f1_score, r), fig_erds, fig_pred = validation(self.raw, 'baseline', self.event_id, model=self.model, state_change_threshold=0.7)
+        (precision, recall, f1_score, r), fig_erds, fig_pred = validation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7)
         fig_erds.savefig('./tests/data/erds.pdf')
         fig_pred.savefig('./tests/data/pred.pdf')   
 

+ 6 - 2
backend/training.py

@@ -16,6 +16,9 @@ import bci_core.model as bci_model
 from dataloaders import neo
 
 
+logging.basicConfig(level=logging.INFO)
+
+
 def train_model(raw, event_id, model_type='baseline'):
     """
     """
@@ -111,14 +114,15 @@ if __name__ == '__main__':
     data_dir = f'./data/{subj_name}/train/'
     model_dir = './static/models/'
 
-    info = yaml.safe_load(os.path.join(data_dir, 'info.yml'))
+    with open(os.path.join(data_dir, 'info.yml'), 'r') as f:
+        info = yaml.safe_load(f)
     sessions = info['sessions']
     event_id = {'rest': 0}
     for f in sessions.keys():
         event_id[f] = neo.FINGERMODEL_IDS[f]
     
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False, ori_epoch_length=5)
+    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=1., ori_epoch_length=5, mov_trial_ind=[6], rest_trial_ind=[7])
 
     # train model
     model = train_model(raw, event_id=event_id, model_type=model_type)

+ 16 - 21
backend/validation.py

@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
 import mne
 import yaml
 import os
-import joblib
+import logging
 from scipy import stats
 from dataloaders import neo
 import bci_core.online as online
@@ -15,6 +15,9 @@ import bci_core.utils as bci_utils
 import bci_core.viz as bci_viz
 
 
+logging.basicConfig(level=logging.INFO)
+
+
 class DataGenerator:
     def __init__(self, fs, X):
         self.fs = int(fs)
@@ -32,11 +35,10 @@ class DataGenerator:
             yield i / self.fs, self.get_data_batch(i)
 
 
-def validation(raw_val, model_type, event_id, model, state_change_threshold=0.8):
+def validation(raw_val, event_id, model, state_change_threshold=0.8):
     """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
     Args:
         raw (mne.io.Raw)
-        model_type (string): type of model to train, baseline or riemann
         event_id (dict)
         model: validate existing model, 
         state_change_threshold (float): default 0.8
@@ -58,12 +60,8 @@ def validation(raw_val, model_type, event_id, model, state_change_threshold=0.8)
                                         mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'], 
                                         use_original_label=True)
     
-    if model_type == 'baseline':
-        hmm_model = online.BaselineHMM(model, state_change_threshold=state_change_threshold)
-    else:
-        raise NotImplementedError
-    controller = online.Controller(0, None)
-    controller.set_real_feedback_model(hmm_model)
+    
+    controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
 
     # validate with the second half
     val_data = raw_val.get_data()
@@ -107,27 +105,24 @@ if __name__ == '__main__':
     model_type = 'baseline'
     # TODO: load subject config
 
-    data_dir = f'./data/{subj_name}/val/'
-    model_path = f'./static/models/{subj_name}/scis.pkl'
+    data_dir = f'./data/{subj_name}/train/'
+    model_path = f'./static/models/{subj_name}/baseline_rest+cylinder_11-15-2023-21-34-41.pkl'
 
-    info = yaml.safe_load(os.path.join(data_dir, 'info.yml'))
+    with open(os.path.join(data_dir, 'info.yml'), 'r') as f:
+        info = yaml.safe_load(f)
     sessions = info['sessions']
     event_id = {'rest': 0}
     for f in sessions.keys():
         event_id[f] = neo.FINGERMODEL_IDS[f]
     
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False, ori_epoch_length=5)
-
-    # load model
-    model = joblib.load(model_path)
-    model_type, events = bci_utils.parse_model_type(model_path)
+    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[6], rest_trial_ind=[7])
 
+    # do validations
     metrics, fig_erds, fig_pred = validation(raw, 
-                                             model_type, 
                                              event_id, 
-                                             model=model, 
-                                             state_change_threshold=0.8)
+                                             model=model_path, 
+                                             state_change_threshold=0.95)
     fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   
-    print(metrics)
+    logging.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')