Quellcode durchsuchen

Merge branch 'four-state-paradigm'

# Conflicts:
#	backend/band_selection.py
#	backend/dataloaders/neo.py
#	backend/online_sim.py
#	backend/tests/test_neoloader.py
#	backend/tests/test_training.py
#	backend/tests/test_validation.py
#	backend/training.py
#	backend/validation.py
dk vor 1 Jahr
Ursprung
Commit
2a3db55a0c

+ 1 - 4
backend/band_selection.py

@@ -49,13 +49,10 @@ data_dir = f'./data/{subj_name}/'
 with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
     info = yaml.safe_load(f)
 sessions = info['sessions']
-event_id = {'rest': 0}
-for f in sessions.keys():
-    event_id[f] = neo.FINGERMODEL_IDS[f]
 
 trial_duration = config_info['buffer_length']
 # preprocess raw
-raw = neo.raw_loader(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+raw, event_id = neo.raw_loader(data_dir, sessions, do_rereference=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5)
 
 ###############################################################################
 # Pipeline with a frequency band selection based on the class distinctiveness

+ 0 - 60
backend/dataloaders/library_ieeg.py

@@ -1,60 +0,0 @@
-from scipy import io as sio
-import mne
-import numpy as np
-from .utils import upsample_events
-
-
-# loader for test data
-def raw_preprocessing(data_file, finger_model='cylinder', epoch_time=1., fs=1000):
-    data = sio.loadmat(data_file, simplify_cells=True, squeeze_me=True)
-    # to double
-    raw = data['data'].astype(np.float64).T * 0.0298 * 1e-6  # to V
-    stim_events = data['stim'].astype(np.float64)
-    # deal with line noise
-    raw = mne.filter.notch_filter(raw, fs, [60, 120, 180], trans_bandwidth=3, verbose=False)
-
-    events = extract_events(stim_events, fs)
-    # upsampling 
-    events = upsample_events(events, int(epoch_time * fs))
-
-    info = mne.create_info([f'ch_{i}' for i in range(raw.shape[0])], sfreq=fs, ch_types='ecog')
-
-    # build raw
-    raw = mne.io.RawArray(raw, info)
-
-    annotations = mne.annotations_from_events(events, fs, {1: finger_model, 0: 'rest'})
-    raw.set_annotations(annotations)
-    return raw
-
-
-def extract_events(stim_events, fs=1000.):
-    diff_stim = np.diff(stim_events)
-
-    shift_idx = int(0.5 * fs)  # shift by 500 ms, compensate for reaction time
-
-    # hand only
-    onsets = np.flatnonzero(diff_stim == 12) + shift_idx
-    
-    offsets = np.flatnonzero(diff_stim == -12)
-
-    # handle cut
-    if len(onsets) != len(offsets):
-        # cut first trial
-        if offsets[0] <= onsets[0]:
-            offsets = offsets[1:]
-        # cut last trial
-        else:
-            onsets = onsets[:-1]
-
-    rest_onset = offsets + shift_idx
-    if len(np.unique(offsets - onsets)) > 1:
-        raise ValueError('Unequal trial length?')
-    trial_length = (offsets - onsets)[0]
-
-    # build events
-    events = np.zeros((len(onsets) * 2, 3), dtype=np.int64)
-    events[::2, 0] = onsets
-    events[:, 1] = trial_length
-    events[1::2, 0] = rest_onset
-    events[::2, 2] = 1
-    return events

+ 31 - 42
backend/dataloaders/neo.py

@@ -4,12 +4,11 @@ import json
 import mne
 import glob
 import pyedflib
-from scipy import signal
 from .utils import upsample_events
 from settings.config import settings
 
-
 FINGERMODEL_IDS = settings.FINGERMODEL_IDS
+FINGERMODEL_IDS_INVERSE = settings.FINGERMODEL_IDS_INVERSE
 
 CONFIG_INFO = settings.CONFIG_INFO
 
@@ -17,31 +16,25 @@ CONFIG_INFO = settings.CONFIG_INFO
 def raw_loader(data_root, session_paths:dict, 
                       do_rereference=True,
                       upsampled_epoch_length=1., 
-                      ori_epoch_length=5, 
-                      unify_label=True,
-                      mov_trial_ind=[2, 3],
-                      rest_trial_ind=[4]):
+                      ori_epoch_length=5):
     """
     Params:
-        subj_root: 
+        data_root: 
         session_paths: dict of lists
         do_rereference (bool): do common average rereference or not
         upsampled_epoch_length (None or float): None: do not do upsampling
         ori_epoch_length (int or 'varied'): original epoch length in second
-        unify_label (True, original data didn't use unified event label, do unification (old pony format); False use original)
-        mov_trial_ind: only used when unify_label == True, suggesting the raw file's annotations didn't use unified labels (old pony format)
-        rest_trial_ind: only used when unify_label == True, 
     """
     raws_loaded = load_sessions(data_root, session_paths, do_rereference)
     # process event
     raws = []
+    event_id = {}
     for (finger_model, raw) in raws_loaded:
         fs = raw.info['sfreq']
-        events, _ = mne.events_from_annotations(raw)
+        {d: int(d) for d in np.unique(raw.annotations.description)}
+        events, _ = mne.events_from_annotations(raw, event_id={d: int(d) for d in np.unique(raw.annotations.description)})
 
-        if not unify_label:
-            mov_trial_ind = [FINGERMODEL_IDS[finger_model]]
-            rest_trial_ind = [0]
+        event_id = event_id | {FINGERMODEL_IDS_INVERSE[int(d)]: int(d) for d in np.unique(raw.annotations.description)}
         
         if isinstance(ori_epoch_length, int) or isinstance(ori_epoch_length, float):
             trial_duration = ori_epoch_length
@@ -49,21 +42,21 @@ def raw_loader(data_root, session_paths:dict,
             trial_duration = None
         else:
             raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
-        events = reconstruct_events(events, fs, finger_model, 
-                                    mov_trial_ind=mov_trial_ind,
-                                    rest_trial_ind=rest_trial_ind,
-                                    trial_duration=trial_duration, 
-                                    use_original_label=not unify_label)
+        events = reconstruct_events(events, fs, 
+                                    trial_duration=trial_duration)
         if upsampled_epoch_length is not None:
             events = upsample_events(events, int(fs * upsampled_epoch_length))
-        annotations = mne.annotations_from_events(events, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
+        
+        event_desc = {e: FINGERMODEL_IDS_INVERSE[e] for e in np.unique(events[:, 2])}
+        annotations = mne.annotations_from_events(events, fs, event_desc)
         raw.set_annotations(annotations)
         raws.append(raw)
 
     raws = mne.concatenate_raws(raws)
+
     raws.load_data()
 
-    return raws
+    return raws, event_id
 
 
 def preprocessing(raw, do_rereference=True):
@@ -78,35 +71,27 @@ def preprocessing(raw, do_rereference=True):
     return raw
 
 
-def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind=[2, 3], rest_trial_ind=[4], use_original_label=False):
+def reconstruct_events(events, fs, trial_duration=5):
     """重构出事件序列中的单独运动事件
     Args:
-        fs: int
-        finger_model: 
+        events (np.ndarray): 
+        fs (float):
+        trial_duration (float or None or dict): None means variable epoch length, dict means there are different trial durations for different trials 
     """
     # Trial duration are fixed to be ? seconds.
-    # initialRest: 1, miFailed & miSuccess: 2 & 3, rest: 4
-    # ignore initialRest
     # extract trials
-
-    deduplicated_mov = np.diff(np.isin(events[:, 2], mov_trial_ind), prepend=0) == 1
-    deduplicated_rest = np.diff(np.isin(events[:, 2], rest_trial_ind), prepend=0) == 1
-    trials_ind_deduplicated = np.flatnonzero(np.logical_or(deduplicated_mov, deduplicated_rest))
+    
+    trials_ind_deduplicated = np.flatnonzero(np.diff(events[:, 2], prepend=0) != 0)
     events_new = events[trials_ind_deduplicated]
     if trial_duration is None:
         events_new[:-1, 1] = np.diff(events_new[:, 0])
         events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
+    elif isinstance(trial_duration, dict):
+        for e in trial_duration.keys():
+            events_new[events_new[:, 2] == e] = trial_duration[e]
     else:
         events_new[:, 1] = int(trial_duration * fs)
-    events_final = events_new.copy()
-    if (not use_original_label) and (finger_model is not None):
-        # process mov
-        ind_mov = np.flatnonzero(np.isin(events_new[:, 2], mov_trial_ind))
-        events_final[ind_mov, 2] = FINGERMODEL_IDS[finger_model] 
-        # process rest
-        ind_rest = np.flatnonzero(np.isin(events_new[:, 2], rest_trial_ind))
-        events_final[ind_rest, 2] = 0
-    return events_final
+    return events_new
 
 
 def load_sessions(data_root, session_names: dict, do_rereference=True):
@@ -171,11 +156,15 @@ def load_neuracle(data_dir, data_type='ecog'):
         onset, duration, content = f_evt.readAnnotations()
         onset = np.array(onset) - start_time_point * 1e-3  # correct by start time point
         onset = (onset * sfreq).astype(np.int64)
+        try:
+            content = content.astype(np.int64)  # use original event code
+        except ValueError:
+            event_mapping = {c: i for i, c in enumerate(np.unique(content))}
+            content = [event_mapping[i] for i in content]
 
         duration = (np.array(duration) * sfreq).astype(np.int64) 
-        event_mapping = {c: i for i, c in enumerate(np.unique(content))}
-        event_ids = [event_mapping[i] for i in content]
-        events = np.stack((onset, duration, event_ids), axis=1)
+
+        events = np.stack((onset, duration, content), axis=1)
         
         annotations = mne.annotations_from_events(events, sfreq)
         raw.set_annotations(annotations)

+ 12 - 41
backend/device/fubo_pneumatic_finger.py

@@ -25,15 +25,16 @@ def get_serial_ports():
 
 class FuboPneumaticFingerClient:
     """富伯客户端"""
-
-    FLEX_CMD = b"F"
-    EXTEND_CMD = b"E"
-    BALL_CMD = b"B"
-    CYLINDER_CMD = b"C"
-    DOUBLE_CMD = b"D"
-    TREBLE_CMD = b"T"
-    RELEASE_CMD = b"R"
-
+    
+    COMMAND_TABLE = {
+        'release': b"R",
+        'cylinder': b"C",
+        'ball': b"B",
+        'flex': b"F",
+        'double': b"D",
+        'treble': b"T",
+        'extend': b"E",
+    }
     def __init__(self, init_params=None):
         self.baud_rate = 9600
         self.data_bite = 8
@@ -71,40 +72,10 @@ class FuboPneumaticFingerClient:
             logger.warning(warning_info)
             return 0
     
-    def release(self):
-        self.ser.write(self.RELEASE_CMD)
-        return self.ser.read()
-
-    def extend(self):
-        self.ser.write(self.EXTEND_CMD)
-        return self.ser.read()
-
-    def reconnect(self):
-        self.close()
-        return self.connect()
-    
-    def start(self, model=None):
-        if (model == "flex") or (model is None):
-            self.ser.write(self.FLEX_CMD)
-        elif model == "ball":
-            self.ser.write(self.BALL_CMD)
-        elif model == "cylinder":
-            self.ser.write(self.CYLINDER_CMD)
-        elif model == "double":
-            self.ser.write(self.DOUBLE_CMD)
-        elif model == "treble":
-            self.ser.write(self.TREBLE_CMD)
+    def start(self, command):
+        self.ser.write(self.COMMAND_TABLE[command])
         return self.ser.read()
 
-    def start_round(self, model=None, time_interval=5):        
-        self.start(model)
-        time.sleep(time_interval)
-        self.extend()
-        return 1
-
-    def stop(self):
-        return 1
-
     def status(self):
         status = {"is_connected": self.is_connected}
         return status

+ 2 - 15
backend/online_sim.py

@@ -143,13 +143,6 @@ def simulation(raw_val, event_id, model,
     fs = raw_val.info['sfreq']
 
     events_val, _ = mne.events_from_annotations(raw_val, event_id)
-        
-    events_val = neo.reconstruct_events(events_val, 
-                                        fs, 
-                                        finger_model=None,
-                                        rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'], 
-                                        mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'], 
-                                        use_original_label=True)
 
     # run with and without hmm
     fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, 
@@ -192,17 +185,11 @@ if __name__ == '__main__':
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
-    event_id = {'rest': 0}
-    for f in sessions.keys():
-        event_id[f] = neo.FINGERMODEL_IDS[f]
     
     # preprocess raw
     trial_time = 5.
-    raw = neo.raw_loader(data_dir, sessions, 
-                                unify_label=True, 
-                                ori_epoch_length=trial_time, 
-                                mov_trial_ind=[2], 
-                                rest_trial_ind=[1], 
+    raw, event_id = neo.raw_loader(data_dir, sessions, 
+                                ori_epoch_length=trial_time,
                                 upsampled_epoch_length=None)
     
     # load model

+ 4 - 2
backend/settings/config.py

@@ -34,7 +34,8 @@ class Settings:
         'ball': 2,
         'flex': 3,
         'double': 4,
-        'treble': 5
+        'treble': 5,
+        'extend': 6
     }
     FINGERMODEL_IDS_INVERSE = {
         0: 'rest',
@@ -42,7 +43,8 @@ class Settings:
         2: 'ball',
         3: 'flex',
         4: 'double',
-        5: 'treble'
+        5: 'treble',
+        6: 'extend'
     }
     PROJECT_VERSION: str = '0.0.1'
     DATA_PATH = './data'

+ 13 - 23
backend/tests/test_neoloader.py

@@ -7,46 +7,36 @@ import numpy as np
 class TestDataloader(unittest.TestCase):
     def test_load_sample_data(self):
         root_path = './tests/data'
-        sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
-        event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
+        sessions = {'flex': ['1']}
 
-        raw = neo.raw_loader(root_path, sessions, unify_label=True)
+        raw, event_id = neo.raw_loader(root_path, sessions, upsampled_epoch_length=1.)
         events, event_id = mne.events_from_annotations(raw, event_id=event_id)
         events, events_cnt = np.unique(events[:, -1], return_counts=True)
-        self.assertTrue(np.allclose(events_cnt, (300, 150, 150)))
+        self.assertTrue(np.allclose(events_cnt, (75, 75)))
 
     def test_load_session(self):
         root_path = './tests/data'
-        sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4']}
+        sessions = {'flex': ['1', '3'], 'ball': ['2']}
         raws = neo.load_sessions(root_path, sessions)
         # test if interleaved
         sess_f = tuple(f for f, r in raws)
         self.assertEqual(len(raws), 3)
-        self.assertTupleEqual(sess_f, ('cylinder', 'ball', 'cylinder'))
+        self.assertTupleEqual(sess_f, ('flex', 'ball', 'flex'))
 
 
     def test_event_parser(self):
         # fixed length
         fs = 100
         test_event = np.array([[0, 0, 4], [100, 0, 4], [600, 0, 3], [700, 0, 3], [1000, 0, 4], [1100, 0, 4]])
-        gt = np.array([[0, 400, 0], [600, 400, 2], [1000, 400, 0]])
-        ret = neo.reconstruct_events(test_event, fs, 'ball', trial_duration=4)
-        self.assertTrue(np.allclose(ret, gt))
-        # varing length
-        gt = np.array([[0, 600, 0], [600, 400, 2], [1000, 100, 0]])
-        ret = neo.reconstruct_events(test_event, fs, 'ball', trial_duration=None)
-        self.assertTrue(np.allclose(ret, gt))
-        # change indices
-        gt = np.array([[0, 400, 2], [600, 400, 0], [1000, 400, 2]])
-        ret = neo.reconstruct_events(test_event, fs, 'ball', mov_trial_ind=[4], rest_trial_ind=[2, 3], trial_duration=4)
-        self.assertTrue(np.allclose(ret, gt))
-        # use original indices
         gt = np.array([[0, 400, 4], [600, 400, 3], [1000, 400, 4]])
-        ret = neo.reconstruct_events(test_event, fs, None, trial_duration=4, use_original_label=True, mov_trial_ind=[3], rest_trial_ind=[4])
+        ret = neo.reconstruct_events(test_event, fs, trial_duration=4)
         self.assertTrue(np.allclose(ret, gt))
-        # use original indices, 
-        gt = np.array([[0, 400, 0], [600, 400, 1], [1000, 400, 0]])
-        test_event = np.array([[0, 0, 0], [100, 0, 0], [600, 0, 1], [700, 0, 1], [1000, 0, 0], [1100, 0, 0]])
-        ret = neo.reconstruct_events(test_event, fs, None, trial_duration=4, use_original_label=True, mov_trial_ind=[1], rest_trial_ind=[0])
+        # duration as dict
+        gt = np.array([[0, 400, 4], [600, 200, 3], [1000, 400, 4]])
+        trial_duration = {4: 4., 3: 2.}
+        ret = neo.reconstruct_events(test_event, fs, trial_duration=trial_duration)
+        # varing length
+        gt = np.array([[0, 600, 4], [600, 400, 3], [1000, 100, 4]])
+        ret = neo.reconstruct_events(test_event, fs, trial_duration=None)
         self.assertTrue(np.allclose(ret, gt))
 

+ 4 - 6
backend/tests/test_online.py

@@ -3,7 +3,7 @@ import shutil
 import random
 import bci_core.online as online
 import training
-from dataloaders import library_ieeg
+from dataloaders import neo
 from online_sim import DataGenerator
 import unittest
 import numpy as np
@@ -14,10 +14,8 @@ class TestOnline(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         root_path = './tests/data'
-        event_id = {'ball': 2, 'rest': 0}
 
-        raw = library_ieeg.raw_preprocessing(os.path.join(root_path, 'ecog-data/1', 'bp_mot_t_h.mat'), finger_model='ball')
-        raw = raw.pick_channels([raw.info['ch_names'][i] for i in [5,6,7,12,13,14,20,21]])
+        raw, event_id = neo.raw_loader(root_path, {'flex': ['1']})
         
         model = training.train_model(raw, event_id, model_type='baseline')
         
@@ -39,7 +37,7 @@ class TestOnline(unittest.TestCase):
         for time, data in self.data_gen.loop():
             cls = controller.step_decision(data)
             rets.append(cls)
-        self.assertTrue(np.allclose(np.unique(rets), [0, 2]))
+        self.assertTrue(np.allclose(np.unique(rets), [0, 3]))
     
     def test_virtual_feedback(self):
         controller = online.Controller(1, None)
@@ -68,7 +66,7 @@ class TestOnline(unittest.TestCase):
         for i, (time, data) in zip(range(300), self.data_gen.loop()):
             cls = controller.decision(data)
             rets.append(cls)
-        self.assertTrue(np.allclose(np.unique(rets), [-1, 0, 2]))
+        self.assertTrue(np.allclose(np.unique(rets), [-1, 0, 3]))
 
 
 class TestHMM(unittest.TestCase):

+ 1 - 9
backend/tests/test_peripheral_hand.py

@@ -16,22 +16,14 @@ class TestPeripheralHand(unittest.TestCase):
         self.assertTrue(client.is_connected)
         client.close()
 
-
     def test_client_close_success(self):
         client = FuboPneumaticFingerClient(init_params)
         client.close()
         self.assertFalse(client.is_connected)
 
-    def test_start_flex_and_extend_success(self):
-        client = FuboPneumaticFingerClient(init_params)
-        client.start_round(time_interval=7)
-        time.sleep(3)
-        client.close()
-
-
     def test_start_extend_success(self):
         client = FuboPneumaticFingerClient(init_params)
-        client.extend()
+        client.start('extend')
         time.sleep(3)
         client.close()
 

+ 3 - 6
backend/tests/test_training.py

@@ -7,20 +7,17 @@ from dataloaders import neo
 from bci_core.feature_extractors import FeatExtractor
 from bci_core.model import baseline_model, riemann_model, ChannelScaler
 import shutil
-from sklearn.utils.validation import check_is_fitted, NotFittedError
+from sklearn.utils.validation import check_is_fitted
 from sklearn.pipeline import Pipeline
-from sklearn.ensemble import StackingClassifier
 
 
 class TestTraining(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         root_path = './tests/data'
-        sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
-        cls.event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
+        sessions = {'flex': ['1', '2']}
 
-        raw = neo.raw_loader(root_path, sessions, unify_label=True)
-        raw.drop_channels(['T3', 'T4', 'A1', 'A2', 'T5', 'T6', 'M1', 'M2', 'Fp1', 'Fp2', 'F7', 'F8', 'O1', 'Oz', 'O2', 'F3', 'F4', 'Fz'])
+        raw, cls.event_id = neo.raw_loader(root_path, sessions)
         cls.raw = raw
     
     def test_training_baseline(self):

+ 4 - 8
backend/tests/test_validation.py

@@ -7,7 +7,7 @@ import shutil
 from bci_core import utils as ana_utils
 from bci_core.online import model_loader
 from training import train_model, model_saver
-from dataloaders import library_ieeg
+from dataloaders import neo
 from online_sim import simulation
 from validation import val_by_epochs
 
@@ -16,10 +16,8 @@ class TestOnlineSim(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         root_path = './tests/data'
-        cls.event_id = {'ball': 2, 'rest': 0}
 
-        raw = library_ieeg.raw_preprocessing(os.path.join(root_path, 'ecog-data/1', 'bp_mot_t_h.mat'), finger_model='ball')
-        raw = raw.pick_channels([raw.info['ch_names'][i] for i in [5,6,7,12,13,14,20,21]])
+        raw, cls.event_id = neo.raw_loader(root_path, {'flex': ['1', '2']})
         cls.raw = raw
         # split into 2 pieces
         t_min, t_max = raw.times[0], raw.times[-1]
@@ -57,10 +55,8 @@ class TestOnlineSim(unittest.TestCase):
                              state_trans_prob=0.7)
         metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=model, epoch_length=1., step_length=0.1)
         fig_pred.savefig('./tests/data/pred.pdf')   
-
-        print(metric_hmm)
-        self.assertTrue(metric_hmm[-2] > 0.9)  # f1-score (with hmm)
-        self.assertTrue(metric_nohmm[-2] < 0.5)  # f1-score (without hmm)
+        self.assertTrue(metric_hmm[-2] > 0.3)  # f1-score (with hmm)
+        self.assertTrue(metric_nohmm[-2] < 0.15)  # f1-score (without hmm)
     
     def test_val_model(self):
         metrices, fig_conf = val_by_epochs(self.raw_val, self.model_path, self.event_id, 1.)

+ 1 - 4
backend/training.py

@@ -142,13 +142,10 @@ if __name__ == '__main__':
     with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
-    event_id = {'rest': 0}
-    for f in sessions.keys():
-        event_id[f] = neo.FINGERMODEL_IDS[f]
     
     trial_duration = config_info['buffer_length']
     # preprocess raw
-    raw = neo.raw_loader(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+    raw, event_id = neo.raw_loader(data_dir, sessions, upsampled_epoch_length=trial_duration, ori_epoch_length=5)
 
     # train model
     model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)

+ 2 - 8
backend/validation.py

@@ -107,18 +107,12 @@ if __name__ == '__main__':
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
-    event_id = {'rest': 0}
-    for f in sessions.keys():
-        event_id[f] = neo.FINGERMODEL_IDS[f]
     
     # preprocess raw
     trial_time = 5.
     upsampled_trial_duration = config_info['buffer_length']
-    raw = neo.raw_loader(data_dir, sessions, 
-                                unify_label=True, 
-                                ori_epoch_length=trial_time, 
-                                mov_trial_ind=[2], 
-                                rest_trial_ind=[1],
+    raw, event_id = neo.raw_loader(data_dir, sessions, 
+                                ori_epoch_length=trial_time,
                                 upsampled_epoch_length=upsampled_trial_duration)
     
     fs = raw.info['sfreq']