Bladeren bron

重构NEO loader代码以适配新范式,去除对albatross和kai miller数据集的兼容

dk 1 jaar geleden
bovenliggende
commit
e786dae63b

+ 1 - 4
backend/band_selection.py

@@ -49,13 +49,10 @@ data_dir = f'./data/{subj_name}/'
 with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
     info = yaml.safe_load(f)
 sessions = info['sessions']
-event_id = {'rest': 0}
-for f in sessions.keys():
-    event_id[f] = neo.FINGERMODEL_IDS[f]
 
 trial_duration = config_info['buffer_length']
 # preprocess raw
-raw = neo.raw_preprocessing(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+raw, event_id = neo.raw_preprocessing(data_dir, sessions, do_rereference=True,upsampled_epoch_length=trial_duration, ori_epoch_length=5)
 
 ###############################################################################
 # Pipeline with a frequency band selection based on the class distinctiveness

+ 0 - 60
backend/dataloaders/library_ieeg.py

@@ -1,60 +0,0 @@
-from scipy import io as sio
-import mne
-import numpy as np
-from .utils import upsample_events
-
-
-# loader for test data
-def raw_preprocessing(data_file, finger_model='cylinder', epoch_time=1., fs=1000):
-    data = sio.loadmat(data_file, simplify_cells=True, squeeze_me=True)
-    # to double
-    raw = data['data'].astype(np.float64).T * 0.0298 * 1e-6  # to V
-    stim_events = data['stim'].astype(np.float64)
-    # deal with line noise
-    raw = mne.filter.notch_filter(raw, fs, [60, 120, 180], trans_bandwidth=3, verbose=False)
-
-    events = extract_events(stim_events, fs)
-    # upsampling 
-    events = upsample_events(events, int(epoch_time * fs))
-
-    info = mne.create_info([f'ch_{i}' for i in range(raw.shape[0])], sfreq=fs, ch_types='ecog')
-
-    # build raw
-    raw = mne.io.RawArray(raw, info)
-
-    annotations = mne.annotations_from_events(events, fs, {1: finger_model, 0: 'rest'})
-    raw.set_annotations(annotations)
-    return raw
-
-
-def extract_events(stim_events, fs=1000.):
-    diff_stim = np.diff(stim_events)
-
-    shift_idx = int(0.5 * fs)  # shift by 500 ms, compensate for reaction time
-
-    # hand only
-    onsets = np.flatnonzero(diff_stim == 12) + shift_idx
-    
-    offsets = np.flatnonzero(diff_stim == -12)
-
-    # handle cut
-    if len(onsets) != len(offsets):
-        # cut first trial
-        if offsets[0] <= onsets[0]:
-            offsets = offsets[1:]
-        # cut last trial
-        else:
-            onsets = onsets[:-1]
-
-    rest_onset = offsets + shift_idx
-    if len(np.unique(offsets - onsets)) > 1:
-        raise ValueError('Unequal trial length?')
-    trial_length = (offsets - onsets)[0]
-
-    # build events
-    events = np.zeros((len(onsets) * 2, 3), dtype=np.int64)
-    events[::2, 0] = onsets
-    events[:, 1] = trial_length
-    events[1::2, 0] = rest_onset
-    events[::2, 2] = 1
-    return events

+ 32 - 50
backend/dataloaders/neo.py

@@ -4,12 +4,11 @@ import json
 import mne
 import glob
 import pyedflib
-from scipy import signal
 from .utils import upsample_events
 from settings.config import settings
 
-
 FINGERMODEL_IDS = settings.FINGERMODEL_IDS
+FINGERMODEL_IDS_INVERSE = settings.FINGERMODEL_IDS_INVERSE
 
 CONFIG_INFO = settings.CONFIG_INFO
 
@@ -17,31 +16,25 @@ CONFIG_INFO = settings.CONFIG_INFO
 def raw_preprocessing(data_root, session_paths:dict, 
                       do_rereference=True,
                       upsampled_epoch_length=1., 
-                      ori_epoch_length=5, 
-                      unify_label=True,
-                      mov_trial_ind=[2, 3],
-                      rest_trial_ind=[4]):
+                      ori_epoch_length=5):
     """
     Params:
-        subj_root: 
+        data_root: 
         session_paths: dict of lists
         do_rereference (bool): do common average rereference or not
         upsampled_epoch_length (None or float): None: do not do upsampling
         ori_epoch_length (int or 'varied'): original epoch length in second
-        unify_label (True, original data didn't use unified event label, do unification (old pony format); False use original)
-        mov_trial_ind: only used when unify_label == True, suggesting the raw file's annotations didn't use unified labels (old pony format)
-        rest_trial_ind: only used when unify_label == True, 
     """
     raws_loaded = load_sessions(data_root, session_paths)
     # process event
     raws = []
+    event_id = {}
     for (finger_model, raw) in raws_loaded:
         fs = raw.info['sfreq']
-        events, _ = mne.events_from_annotations(raw)
+        {d: int(d) for d in np.unique(raw.annotations.description)}
+        events, _ = mne.events_from_annotations(raw, event_id={d: int(d) for d in np.unique(raw.annotations.description)})
 
-        if not unify_label:
-            mov_trial_ind = [FINGERMODEL_IDS[finger_model]]
-            rest_trial_ind = [0]
+        event_id = event_id | {FINGERMODEL_IDS_INVERSE[int(d)]: int(d) for d in np.unique(raw.annotations.description)}
         
         if isinstance(ori_epoch_length, int) or isinstance(ori_epoch_length, float):
             trial_duration = ori_epoch_length
@@ -49,20 +42,19 @@ def raw_preprocessing(data_root, session_paths:dict,
             trial_duration = None
         else:
             raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
-        events = reconstruct_events(events, fs, finger_model, 
-                                    mov_trial_ind=mov_trial_ind,
-                                    rest_trial_ind=rest_trial_ind,
-                                    trial_duration=trial_duration, 
-                                    use_original_label=not unify_label)
+        events = reconstruct_events(events, fs, 
+                                    trial_duration=trial_duration)
         if upsampled_epoch_length is not None:
             events = upsample_events(events, int(fs * upsampled_epoch_length))
-        annotations = mne.annotations_from_events(events, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
+        
+        event_desc = {e: FINGERMODEL_IDS_INVERSE[e] for e in np.unique(events[:, 2])}
+        annotations = mne.annotations_from_events(events, fs, event_desc)
         raw.set_annotations(annotations)
         raws.append(raw)
 
     raws = mne.concatenate_raws(raws)
-    raws.load_data()
 
+    raws.load_data()
     if do_rereference:
         # common average
         raws.set_eeg_reference('average')
@@ -71,38 +63,30 @@ def raw_preprocessing(data_root, session_paths:dict,
     # filter 50Hz
     raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
 
-    return raws
+    return raws, event_id
 
 
-def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind=[2, 3], rest_trial_ind=[4], use_original_label=False):
+def reconstruct_events(events, fs, trial_duration=5):
-    """重构出事件序列中的单独运动事件
-    Args:
-        fs: int
-        finger_model: 
-    """
-    # Trial duration are fixed to be ? seconds.
-    # initialRest: 1, miFailed & miSuccess: 2 & 3, rest: 4
-    # ignore initialRest
-    # extract trials
-
-    deduplicated_mov = np.diff(np.isin(events[:, 2], mov_trial_ind), prepend=0) == 1
-    deduplicated_rest = np.diff(np.isin(events[:, 2], rest_trial_ind), prepend=0) == 1
-    trials_ind_deduplicated = np.flatnonzero(np.logical_or(deduplicated_mov, deduplicated_rest))
-    events_new = events[trials_ind_deduplicated]
+    """Collapse each run of repeated event codes into a single trial event.
+
+    Args:
+        events (np.ndarray): integer array with one row per raw event; column 0
+            is the onset in samples, column 1 the duration and column 2 the
+            event code.
+        fs (float): sampling rate, used to convert seconds into samples.
+        trial_duration (float or None or dict): a number sets every trial to
+            that many seconds; None derives each length from the gap to the
+            next trial onset; a dict maps event code -> length in seconds.
+
+    Returns:
+        np.ndarray: a new array with one row per trial (the first row of each
+        run) and column 1 holding the trial length in samples. The input array
+        is not modified.
+    """
+    # A trial starts on the first row and wherever the code differs from the
+    # previous row. Neighbours are compared directly rather than via
+    # np.diff(..., prepend=0), which dropped a leading run whose code is 0.
+    codes = events[:, 2]
+    is_trial_start = np.ones(len(codes), dtype=bool)
+    is_trial_start[1:] = codes[1:] != codes[:-1]
+    events_new = events[is_trial_start]  # boolean indexing returns a copy
     if trial_duration is None:
         events_new[:-1, 1] = np.diff(events_new[:, 0])
         events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
+    elif isinstance(trial_duration, dict):
+        for e, seconds in trial_duration.items():
+            # Write the duration column only (onset and code must survive),
+            # converted from seconds to samples like the fixed-length branch.
+            events_new[events_new[:, 2] == e, 1] = int(seconds * fs)
     else:
         events_new[:, 1] = int(trial_duration * fs)
-    events_final = events_new.copy()
-    if (not use_original_label) and (finger_model is not None):
-        # process mov
-        ind_mov = np.flatnonzero(np.isin(events_new[:, 2], mov_trial_ind))
-        events_final[ind_mov, 2] = FINGERMODEL_IDS[finger_model] 
-        # process rest
-        ind_rest = np.flatnonzero(np.isin(events_new[:, 2], rest_trial_ind))
-        events_final[ind_rest, 2] = 0
-    return events_final
+    return events_new
 
 
 def load_sessions(data_root, session_names: dict):
@@ -117,13 +101,7 @@ def load_sessions(data_root, session_names: dict):
                 i += 1
             except IndexError:
                 continue
-            if glob.glob(os.path.join(data_root, s, 'evt.bdf')):
-                # neo format
-                raw = load_neuracle(os.path.join(data_root, s))
-            else:
-                # kraken format
-                data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
-                raw = mne.io.read_raw_bdf(data_file)
+            raw = load_neuracle(os.path.join(data_root, s))
             raws.append((finger_model, raw))
     return raws  
 
@@ -163,11 +141,15 @@ def load_neuracle(data_dir, data_type='ecog'):
         onset, duration, content = f_evt.readAnnotations()
         onset = np.array(onset) - start_time_point * 1e-3  # correct by start time point
         onset = (onset * sfreq).astype(np.int64)
+        try:
+            content = content.astype(np.int64)  # use original event code
+        except ValueError:
+            event_mapping = {c: i for i, c in enumerate(np.unique(content))}
+            content = [event_mapping[i] for i in content]
 
         duration = (np.array(duration) * sfreq).astype(np.int64) 
-        event_mapping = {c: i for i, c in enumerate(np.unique(content))}
-        event_ids = [event_mapping[i] for i in content]
-        events = np.stack((onset, duration, event_ids), axis=1)
+
+        events = np.stack((onset, duration, content), axis=1)
         
         annotations = mne.annotations_from_events(events, sfreq)
         raw.set_annotations(annotations)

+ 2 - 16
backend/online_sim.py

@@ -135,14 +135,6 @@ def simulation(raw_val, event_id, model,
     fs = raw_val.info['sfreq']
 
     events_val, _ = mne.events_from_annotations(raw_val, event_id)
-        
-    events_val = neo.reconstruct_events(events_val, 
-                                        fs, 
-                                        finger_model=None,
-                                        rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'], 
-                                        mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'], 
-                                        use_original_label=True)
-    
     
     controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
     model_hmm = controller.real_feedback_model
@@ -186,17 +178,11 @@ if __name__ == '__main__':
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
-    event_id = {'rest': 0}
-    for f in sessions.keys():
-        event_id[f] = neo.FINGERMODEL_IDS[f]
     
     # preprocess raw
     trial_time = 5.
-    raw = neo.raw_preprocessing(data_dir, sessions, 
-                                unify_label=True, 
-                                ori_epoch_length=trial_time, 
-                                mov_trial_ind=[2], 
-                                rest_trial_ind=[1], 
+    raw, event_id = neo.raw_preprocessing(data_dir, sessions, 
+                                ori_epoch_length=trial_time,
                                 upsampled_epoch_length=None)
 
     # do validations

+ 13 - 23
backend/tests/test_neoloader.py

@@ -7,46 +7,36 @@ import numpy as np
 class TestDataloader(unittest.TestCase):
     def test_load_sample_data(self):
         root_path = './tests/data'
-        sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
-        event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
+        sessions = {'flex': ['1']}
 
-        raw = neo.raw_preprocessing(root_path, sessions, unify_label=True)
+        raw, event_id = neo.raw_preprocessing(root_path, sessions, upsampled_epoch_length=1.)
         events, event_id = mne.events_from_annotations(raw, event_id=event_id)
         events, events_cnt = np.unique(events[:, -1], return_counts=True)
-        self.assertTrue(np.allclose(events_cnt, (300, 150, 150)))
+        self.assertTrue(np.allclose(events_cnt, (75, 75)))
 
     def test_load_session(self):
         root_path = './tests/data'
-        sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4']}
+        sessions = {'flex': ['1', '3'], 'ball': ['2']}
         raws = neo.load_sessions(root_path, sessions)
         # test if interleaved
         sess_f = tuple(f for f, r in raws)
         self.assertEqual(len(raws), 3)
-        self.assertTupleEqual(sess_f, ('cylinder', 'ball', 'cylinder'))
+        self.assertTupleEqual(sess_f, ('flex', 'ball', 'flex'))
 
 
     def test_event_parser(self):
         # fixed length
         fs = 100
         test_event = np.array([[0, 0, 4], [100, 0, 4], [600, 0, 3], [700, 0, 3], [1000, 0, 4], [1100, 0, 4]])
-        gt = np.array([[0, 400, 0], [600, 400, 2], [1000, 400, 0]])
-        ret = neo.reconstruct_events(test_event, fs, 'ball', trial_duration=4)
-        self.assertTrue(np.allclose(ret, gt))
-        # varing length
-        gt = np.array([[0, 600, 0], [600, 400, 2], [1000, 100, 0]])
-        ret = neo.reconstruct_events(test_event, fs, 'ball', trial_duration=None)
-        self.assertTrue(np.allclose(ret, gt))
-        # change indices
-        gt = np.array([[0, 400, 2], [600, 400, 0], [1000, 400, 2]])
-        ret = neo.reconstruct_events(test_event, fs, 'ball', mov_trial_ind=[4], rest_trial_ind=[2, 3], trial_duration=4)
-        self.assertTrue(np.allclose(ret, gt))
-        # use original indices
         gt = np.array([[0, 400, 4], [600, 400, 3], [1000, 400, 4]])
-        ret = neo.reconstruct_events(test_event, fs, None, trial_duration=4, use_original_label=True, mov_trial_ind=[3], rest_trial_ind=[4])
+        ret = neo.reconstruct_events(test_event, fs, trial_duration=4)
         self.assertTrue(np.allclose(ret, gt))
-        # use original indices, 
-        gt = np.array([[0, 400, 0], [600, 400, 1], [1000, 400, 0]])
-        test_event = np.array([[0, 0, 0], [100, 0, 0], [600, 0, 1], [700, 0, 1], [1000, 0, 0], [1100, 0, 0]])
-        ret = neo.reconstruct_events(test_event, fs, None, trial_duration=4, use_original_label=True, mov_trial_ind=[1], rest_trial_ind=[0])
+        # duration as dict (event code -> seconds)
+        gt = np.array([[0, 400, 4], [600, 200, 3], [1000, 400, 4]])
+        trial_duration = {4: 4., 3: 2.}
+        ret = neo.reconstruct_events(test_event, fs, trial_duration=trial_duration)
+        # check this case before gt/ret are reused by the next one
+        self.assertTrue(np.allclose(ret, gt))
+        # varying length
+        gt = np.array([[0, 600, 4], [600, 400, 3], [1000, 100, 4]])
+        ret = neo.reconstruct_events(test_event, fs, trial_duration=None)
         self.assertTrue(np.allclose(ret, gt))
 

+ 4 - 6
backend/tests/test_online.py

@@ -3,7 +3,7 @@ import shutil
 import random
 import bci_core.online as online
 import training
-from dataloaders import library_ieeg
+from dataloaders import neo
 from online_sim import DataGenerator
 import unittest
 import numpy as np
@@ -14,10 +14,8 @@ class TestOnline(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         root_path = './tests/data'
-        event_id = {'ball': 2, 'rest': 0}
 
-        raw = library_ieeg.raw_preprocessing(os.path.join(root_path, 'ecog-data/1', 'bp_mot_t_h.mat'), finger_model='ball')
-        raw = raw.pick_channels([raw.info['ch_names'][i] for i in [5,6,7,12,13,14,20,21]])
+        raw, event_id = neo.raw_preprocessing(root_path, {'flex': ['1']})
         
         model = training.train_model(raw, event_id, model_type='baseline')
         
@@ -38,7 +36,7 @@ class TestOnline(unittest.TestCase):
         for time, data in self.data_gen.loop():
             cls = controller.step_decision(data)
             rets.append(cls)
-        self.assertTrue(np.allclose(np.unique(rets), [0, 2]))
+        self.assertTrue(np.allclose(np.unique(rets), [0, 3]))
     
     def test_virtual_feedback(self):
         controller = online.Controller(1, None)
@@ -66,7 +64,7 @@ class TestOnline(unittest.TestCase):
         for i, (time, data) in zip(range(300), self.data_gen.loop()):
             cls = controller.decision(data)
             rets.append(cls)
-        self.assertTrue(np.allclose(np.unique(rets), [-1, 0, 2]))
+        self.assertTrue(np.allclose(np.unique(rets), [-1, 0, 3]))
 
 
 class TestHMM(unittest.TestCase):

+ 3 - 6
backend/tests/test_training.py

@@ -7,20 +7,17 @@ from dataloaders import neo
 from bci_core.feature_extractors import FeatExtractor
 from bci_core.model import baseline_model, riemann_model, ChannelScaler
 import shutil
-from sklearn.utils.validation import check_is_fitted, NotFittedError
+from sklearn.utils.validation import check_is_fitted
 from sklearn.pipeline import Pipeline
-from sklearn.ensemble import StackingClassifier
 
 
 class TestTraining(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         root_path = './tests/data'
-        sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
-        cls.event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
+        sessions = {'flex': ['1', '2']}
 
-        raw = neo.raw_preprocessing(root_path, sessions, unify_label=True)
-        raw.drop_channels(['T3', 'T4', 'A1', 'A2', 'T5', 'T6', 'M1', 'M2', 'Fp1', 'Fp2', 'F7', 'F8', 'O1', 'Oz', 'O2', 'F3', 'F4', 'Fz'])
+        raw, cls.event_id = neo.raw_preprocessing(root_path, sessions)
         cls.raw = raw
     
     def test_training_baseline(self):

+ 6 - 7
backend/tests/test_validation.py

@@ -6,7 +6,7 @@ import shutil
 
 from bci_core import utils as ana_utils
 from training import train_model, model_saver
-from dataloaders import library_ieeg
+from dataloaders import neo
 from online_sim import simulation
 from validation import val_by_epochs
 
@@ -15,10 +15,8 @@ class TestOnlineSim(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         root_path = './tests/data'
-        cls.event_id = {'ball': 2, 'rest': 0}
 
-        raw = library_ieeg.raw_preprocessing(os.path.join(root_path, 'ecog-data/1', 'bp_mot_t_h.mat'), finger_model='ball')
-        raw = raw.pick_channels([raw.info['ch_names'][i] for i in [5,6,7,12,13,14,20,21]])
+        raw, cls.event_id = neo.raw_preprocessing(root_path, {'flex': ['1', '2']})
         cls.raw = raw
         # split into 2 pieces
         t_min, t_max = raw.times[0], raw.times[-1]
@@ -53,12 +51,13 @@ class TestOnlineSim(unittest.TestCase):
     def test_sim(self):
         metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7)
         fig_pred.savefig('./tests/data/pred.pdf')   
-
-        self.assertTrue(metric_hmm[-2] > 0.9)  # f1-score (with hmm)
-        self.assertTrue(metric_nohmm[-2] < 0.5)  # f1-score (without hmm)
+        print(metric_hmm, metric_nohmm)
+        self.assertTrue(metric_hmm[-2] > 0.3)  # f1-score (with hmm)
+        self.assertTrue(metric_nohmm[-2] < 0.15)  # f1-score (without hmm)
     
     def test_val_model(self):
         metrices, fig_conf = val_by_epochs(self.raw_val, self.model_path, self.event_id, 1.)
+        print(metrices)
         fig_conf.savefig('./tests/data/conf.pdf')
         self.assertGreater(metrices[0], 0.85)
         self.assertGreater(metrices[1], 0.7)

+ 1 - 4
backend/training.py

@@ -142,13 +142,10 @@ if __name__ == '__main__':
     with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
-    event_id = {'rest': 0}
-    for f in sessions.keys():
-        event_id[f] = neo.FINGERMODEL_IDS[f]
     
     trial_duration = config_info['buffer_length']
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+    raw, event_id = neo.raw_preprocessing(data_dir, sessions, upsampled_epoch_length=trial_duration, ori_epoch_length=5)
 
     # train model
     model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)

+ 2 - 8
backend/validation.py

@@ -106,18 +106,12 @@ if __name__ == '__main__':
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
-    event_id = {'rest': 0}
-    for f in sessions.keys():
-        event_id[f] = neo.FINGERMODEL_IDS[f]
     
     # preprocess raw
     trial_time = 5.
     upsampled_trial_duration = config_info['buffer_length']
-    raw = neo.raw_preprocessing(data_dir, sessions, 
-                                unify_label=True, 
-                                ori_epoch_length=trial_time, 
-                                mov_trial_ind=[2], 
-                                rest_trial_ind=[1],
+    raw, event_id = neo.raw_preprocessing(data_dir, sessions, 
+                                ori_epoch_length=trial_time,
                                 upsampled_epoch_length=upsampled_trial_duration)
     
     fs = raw.info['sfreq']