Bläddra i källkod

Feat: support bipolar rereference

dk 1 år sedan
förälder
incheckning
20b243d573

+ 1 - 1
backend/band_selection.py

@@ -52,7 +52,7 @@ sessions = info['sessions']
 
 trial_duration = config_info['buffer_length']
 # preprocess raw
-raw, event_id = neo.raw_loader(data_dir, sessions, do_rereference=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5)
+raw, event_id = neo.raw_loader(data_dir, sessions, reref_method=config_info['reref'], upsampled_epoch_length=trial_duration, ori_epoch_length=5)
 
 ###############################################################################
 # Pipeline with a frequency band selection based on the class distinctiveness

+ 20 - 18
backend/bci_core/online.py

@@ -4,7 +4,7 @@ import random
 import logging
 import os
 from scipy import signal
-from .utils import parse_model_type
+from .utils import parse_model_type, reref
 
 
 logger = logging.getLogger(__name__)
@@ -21,10 +21,12 @@ class Controller:
     """
     def __init__(self,
                  virtual_feedback_rate=1., 
-                 real_feedback_model=None):
+                 real_feedback_model=None,
+                 reref_method='monopolar'):
         
         self.real_feedback_model = real_feedback_model
         self.virtual_feedback_rate = virtual_feedback_rate
+        self.reref_method = reref_method
 
     def step_decision(self, data, true_label=None):
         """抓握训练调用接口,只进行单次判决,不涉及马尔可夫过程,
@@ -41,7 +43,7 @@ class Controller:
             return virtual_feedback
 
         if self.real_feedback_model is not None:
-            fs, data = self.real_feedback_model.parse_data(data)
+            fs, data = self.parse_data(data)
             p = self.real_feedback_model.step_probability(fs, data)
             logger.debug('step_decison: model probability: {}'.format(str(p)))
             pred = np.argmax(p)
@@ -71,7 +73,8 @@ class Controller:
             int: 统一化标签 (-1: keep, 0: rest, 1: cylinder, 2: ball, 3: flex, 4: double, 5: treble)
         """
         if self.real_feedback_model is not None:
-            real_decision = self.real_feedback_model.viterbi(data)
+            fs, data = self.parse_data(data)
+            real_decision = self.real_feedback_model.viterbi(fs, data)
             # map to unified label
             if real_decision != -1:
                 real_decision = self.real_feedback_model.model.classes_[real_decision]
@@ -96,10 +99,20 @@ class Controller:
                 else:
                     return 10000
         return None
+    
+    def parse_data(self, data):
+        fs, event, data_array = data
+        # do preprocessing
+        data_array = reref(data_array, self.reref_method)
+        return fs, data_array
 
 
 class HMMModel:
-    def __init__(self, transmat=None, n_classes=2, state_trans_prob=0.6, state_change_threshold=0.5):
+    def __init__(self, 
+                 transmat=None, 
+                 n_classes=2, 
+                 state_trans_prob=0.6, 
+                 state_change_threshold=0.5):
         self.n_classes = n_classes
         self.set_current_state(0)
 
@@ -124,21 +137,13 @@ class HMMModel:
         self._probability[current_state] = 1.
     
     def step_probability(self, fs, data):
-        # do preprocessing here
-        # common average
-        data -= data.mean(axis=0)
-        return data
-    
-    def parse_data(self, data):
-        fs, event, data_array = data
-        return fs, data_array
+        raise NotImplementedError
     
-    def viterbi(self, data, return_step_p=False):
+    def viterbi(self, fs, data, return_step_p=False):
         """
             Interface for class decision
 
         """
-        fs, data = self.parse_data(data)
         p = self.step_probability(fs, data)
         if return_step_p:
             return p, self.update_state(p)
@@ -164,7 +169,6 @@ class HMMModel:
     
     @property
     def probability(self):
-        # TODO: return each classes
         return self._probability.copy()
 
 
@@ -179,7 +183,6 @@ class BaselineHMM(HMMModel):
     def step_probability(self, fs, data):
         """Step 
         """
-        data = super(BaselineHMM, self).step_probability(fs, data)
         # filter data
         filter_bank_data = self.feat_extractor.transform(data)
         # downsampling
@@ -203,7 +206,6 @@ class RiemannHMM(HMMModel):
     def step_probability(self, fs, data):
         """Step 
         """
-        data = super(RiemannHMM, self).step_probability(fs, data)
         data = self.feat_extractor.transform(data)
         data = data[None]  # pad trial dimension
         # scale data

+ 17 - 0
backend/bci_core/utils.py

@@ -130,3 +130,20 @@ def multiclass_auc_score(y_true, prob, n_classes=None):
     else:
         raise ValueError
     return auc
+
+
+def reref(data, method):
+    data = data.copy()
+    if method == 'average':
+        data -= data.mean(axis=0)
+        return data
+    elif method == 'bipolar':
+        # neo specific
+        anode = data[[0, 1, 2, 3, 7, 6, 5]]
+        cathode = data[[1, 2, 3, 7, 6, 5, 4]]
+        return anode - cathode
+    elif method == 'monopolar':
+        return data
+    else:
+        raise ValueError(f'Rereference method unacceptable, got {str(method)}, expect "monopolar" or "average" or "bipolar"')
+    

+ 20 - 10
backend/dataloaders/neo.py

@@ -14,18 +14,18 @@ CONFIG_INFO = settings.CONFIG_INFO
 
 
 def raw_loader(data_root, session_paths:dict, 
-                      do_rereference=True,
+                      reref_method='monopolar',
                       upsampled_epoch_length=1., 
                       ori_epoch_length=5):
     """
     Params:
         data_root: 
         session_paths: dict of lists
-        do_rereference (bool): do common average rereference or not
+        reref_method (str): rereference method: monopolar, average, or bipolar
         upsampled_epoch_length (None or float): None: do not do upsampling
         ori_epoch_length (int or 'varied'): original epoch length in second
     """
-    raws_loaded = load_sessions(data_root, session_paths, do_rereference)
+    raws_loaded = load_sessions(data_root, session_paths, reref_method)
     # process event
     raws = []
     event_id = {}
@@ -61,11 +61,22 @@ def raw_loader(data_root, session_paths:dict,
     return raws, event_id
 
 
-def preprocessing(raw, do_rereference=True):
+def reref(raw, method='average'):
+    if method == 'average':
+        return raw.set_eeg_reference('average')
+    elif method == 'bipolar':
+        anode = CONFIG_INFO['strips'][0] + CONFIG_INFO['strips'][1][1:][::-1]
+        cathode = CONFIG_INFO['strips'][0][1:] + CONFIG_INFO['strips'][1][::-1]
+        return mne.set_bipolar_reference(raw, anode, cathode)
+    elif method == 'monopolar':
+        return raw
+    else:
+        raise ValueError(f'Rereference method unacceptable, got {str(method)}, expect "monopolar" or "average" or "bipolar"')
+
+
+def preprocessing(raw, reref_method='monopolar'):
     raw.load_data()
-    if do_rereference:
-        # common average
-        raw.set_eeg_reference('average')
+    raw = reref(raw, reref_method)
     # high pass
     raw = raw.filter(1, None)
     # filter 50Hz
@@ -96,7 +107,7 @@ def reconstruct_events(events, fs, trial_duration=5):
     return events_new
 
 
-def load_sessions(data_root, session_names: dict, do_rereference=True):
+def load_sessions(data_root, session_names: dict, reref_method='monopolar'):
     # return raws for different finger models on an interleaved manner
     raw_cnt = sum(len(session_names[k]) for k in session_names)
     raws = []
@@ -116,8 +127,7 @@ def load_sessions(data_root, session_names: dict, do_rereference=True):
                 data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
                 raw = mne.io.read_raw_bdf(data_file)
             # preprocess raw
-            raw = preprocessing(raw, do_rereference)
-
+            raw = preprocessing(raw, reref_method)
             # append list
             raws.append((finger_model, raw))
     return raws  

+ 5 - 4
backend/online_sim.py

@@ -89,8 +89,8 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
     decision_without_hmm = []
     probs = []
     probs_naive = []
-    for time, data in data_gen.loop(step_length):
-        step_p, cls = model_hmm.viterbi(data, return_step_p=True)
+    for time, (fs, event, data) in data_gen.loop(step_length):
+        step_p, cls = model_hmm.viterbi(fs, data, return_step_p=True)
         if cls >=0:
             cls = model_hmm.model.classes_[cls]
         decision_with_hmm.append((time, cls))  # map to unified label
@@ -115,7 +115,7 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
     accu_hmm = accuracy_score(stim_true, stim_pred)
     accu_naive = accuracy_score(stim_true, stim_pred_naive)
     
-    # hmm
+    # hmm plot
     fig_hmm, axes = plt.subplots(model_hmm.n_classes + 2, 1, sharex=True, figsize=(10, 8))
     axes[0].set_title('True states')
     bci_viz.plot_states((raw.times[0], raw.times[-1]), stim_true, ax=axes[0])
@@ -125,7 +125,7 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
         bci_viz.plot_state_prob_with_cue((raw.times[0], raw.times[-1]), stim_true, probs[:, i], ax=ax)
     fig_hmm.suptitle('With HMM')
     
-    # without hmm
+    # naive plot
     fig_naive, axes = plt.subplots(model_hmm.n_classes + 2, 1, sharex=True, figsize=(10, 8))
     axes[0].set_title('True states')
     bci_viz.plot_states((raw.times[0], raw.times[-1]), stim_true, ax=axes[0])
@@ -215,6 +215,7 @@ if __name__ == '__main__':
     # preprocess raw
     trial_time = 5.
     raw, event_id = neo.raw_loader(data_dir, sessions, 
+                                reref_method=config_info['reref'],
                                 ori_epoch_length=trial_time,
                                 upsampled_epoch_length=None)
     

+ 1 - 0
backend/settings/config.py

@@ -27,6 +27,7 @@ class Settings:
         ],
         'strips': [['CH001', 'CH002', 'CH003', 'CH004'], 
                    ['CH005', 'CH006', 'CH007', 'CH008']],
+        'reref': 'bipolar'
     }
     FINGERMODEL_IDS = {
         'rest': 0,

+ 4 - 3
backend/tests/test_online.py

@@ -15,7 +15,7 @@ class TestOnline(unittest.TestCase):
     def setUpClass(cls):
         root_path = './tests/data'
 
-        raw, event_id = neo.raw_loader(root_path, {'flex': ['1']})
+        raw, event_id = neo.raw_loader(root_path, {'flex': ['1']}, reref_method='bipolar')
         
         model = training.train_model(raw, event_id, model_type='baseline')
         
@@ -23,6 +23,7 @@ class TestOnline(unittest.TestCase):
         cls.model_root = os.path.join(root_path, 'f77cbe10a8de473992542e9f4e913a66')
         cls.model_path = glob(os.path.join(root_path, 'f77cbe10a8de473992542e9f4e913a66', '*.pkl'))[0]
 
+        raw, event_id = neo.raw_loader(root_path, {'flex': ['1']}, reref_method='monopolar')
         cls.data_gen = DataGenerator(raw.info['sfreq'], raw.get_data())
     
     @classmethod
@@ -32,7 +33,7 @@ class TestOnline(unittest.TestCase):
     
     def test_step_feedback(self):
         model_hmm = online.model_loader(self.model_path)
-        controller = online.Controller(0, model_hmm)
+        controller = online.Controller(0, model_hmm, reref_method='bipolar')
         rets = []
         for time, data in self.data_gen.loop():
             cls = controller.step_decision(data)
@@ -61,7 +62,7 @@ class TestOnline(unittest.TestCase):
 
     def test_real_feedback(self):
         model_hmm = online.model_loader(self.model_path)
-        controller = online.Controller(0, model_hmm)
+        controller = online.Controller(0, model_hmm, reref_method='bipolar')
         rets = []
         for i, (time, data) in zip(range(300), self.data_gen.loop()):
             cls = controller.decision(data)

+ 4 - 2
backend/tests/test_validation.py

@@ -19,8 +19,10 @@ class TestOnlineSim(unittest.TestCase):
     def setUpClass(cls):
         root_path = './tests/data'
 
-        raw_train, cls.event_id = neo.raw_loader(root_path, {'flex': ['1']})
-        cls.raw_val, _ = neo.raw_loader(root_path, {'flex': ['2']}, upsampled_epoch_length=None)
+        raw_train, cls.event_id = neo.raw_loader(root_path, {'flex': ['1']}, reref_method='bipolar')
+        cls.raw_val, _ = neo.raw_loader(root_path, {'flex': ['2']}, 
+                                        upsampled_epoch_length=None,
+                                        reref_method='bipolar')
         
         # train with the first half
         model = train_model(raw_train, event_id=cls.event_id, model_type='baseline')

+ 1 - 1
backend/training.py

@@ -148,7 +148,7 @@ if __name__ == '__main__':
     
     trial_duration = config_info['buffer_length']
     # preprocess raw
-    raw, event_id = neo.raw_loader(data_dir, sessions, upsampled_epoch_length=trial_duration, ori_epoch_length=5)
+    raw, event_id = neo.raw_loader(data_dir, sessions, upsampled_epoch_length=trial_duration, ori_epoch_length=5, reref_method=config_info['reref'])
 
     # train model
     model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)

+ 1 - 0
backend/validation.py

@@ -116,6 +116,7 @@ if __name__ == '__main__':
     upsampled_trial_duration = config_info['buffer_length']
     raw, event_id = neo.raw_loader(data_dir, sessions, 
                                 ori_epoch_length=trial_time,
+                                reref_method=config_info['reref'],
                                 upsampled_epoch_length=upsampled_trial_duration)
     
     fs = raw.info['sfreq']