Browse Source

Fix: bug in event reconstruction

dk 1 year ago
parent
commit
01ea619e96
3 changed files with 10 additions and 10 deletions
  1. 8 8
      backend/dataloaders/neo.py
  2. 1 1
      backend/tests/test_neoloader.py
  3. 1 1
      backend/tests/test_training.py

+ 8 - 8
backend/dataloaders/neo.py

@@ -20,14 +20,16 @@ FINGERMODEL_IDS = {
 def raw_preprocessing(data_root, session_paths:dict, 
                       upsampled_epoch_length=1., 
                       ori_epoch_length=5, 
-                      rename_event=True):
+                      unify_label=True,
+                      mov_trial_ind=[2, 3],
+                      rest_trial_ind=[4]):
     """
     Params:
         subj_root: 
         session_paths: dict of lists
         upsampled_epoch_length: 
         ori_epoch_length (int or 'varied'): original epoch length in second
-        rename_event (True, use unified event label, False use original)
+        unify_label (True, original data didn't use unified event label, do unification (old pony format); False use original)
     """
     raws_loaded = load_sessions(data_root, session_paths)
     # process event
@@ -36,10 +38,8 @@ def raw_preprocessing(data_root, session_paths:dict,
         fs = raw.info['sfreq']
         events, _ = mne.events_from_annotations(raw)
 
-        mov_trial_ind = [2, 3]
-        rest_trial_ind = [4]
-        if not rename_event:
-            mov_trial_ind = [finger_model]
+        if not unify_label:
+            mov_trial_ind = [FINGERMODEL_IDS[finger_model]]
             rest_trial_ind = [0]
         
         if isinstance(ori_epoch_length, int) or isinstance(ori_epoch_length, float):
@@ -52,7 +52,7 @@ def raw_preprocessing(data_root, session_paths:dict,
                                     mov_trial_ind=mov_trial_ind,
                                     rest_trial_ind=rest_trial_ind,
                                     trial_duration=trial_duration, 
-                                    use_original_label=not rename_event)
+                                    use_original_label=not unify_label)
         
         events_upsampled = upsample_events(events, int(fs * upsampled_epoch_length))
         annotations = mne.annotations_from_events(events_upsampled, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
@@ -89,7 +89,7 @@ def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind
     else:
         events_new[:, 1] = int(trial_duration * fs)
     events_final = events_new.copy()
-    if not use_original_label and finger_model is not None:
+    if (not use_original_label) and (finger_model is not None):
         # process mov
         ind_mov = np.flatnonzero(np.isin(events_new[:, 2], mov_trial_ind))
         events_final[ind_mov, 2] = FINGERMODEL_IDS[finger_model] 

+ 1 - 1
backend/tests/test_neoloader.py

@@ -10,7 +10,7 @@ class TestDataloader(unittest.TestCase):
         sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
         event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
 
-        raw = neo.raw_preprocessing(root_path, sessions)
+        raw = neo.raw_preprocessing(root_path, sessions, unify_label=True)
         events, event_id = mne.events_from_annotations(raw, event_id=event_id)
         events, events_cnt = np.unique(events[:, -1], return_counts=True)
         self.assertTrue(np.allclose(events_cnt, (300, 150, 150)))

+ 1 - 1
backend/tests/test_training.py

@@ -19,7 +19,7 @@ class TestTraining(unittest.TestCase):
         sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
         cls.event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
 
-        raw = neo.raw_preprocessing(root_path, sessions, rename_event=True)
+        raw = neo.raw_preprocessing(root_path, sessions, unify_label=True)
         raw.drop_channels(['T3', 'T4', 'A1', 'A2', 'T5', 'T6', 'M1', 'M2', 'Fp1', 'Fp2', 'F7', 'F8', 'O1', 'Oz', 'O2', 'F3', 'F4', 'Fz'])
         cls.raw = raw