1
0
Эх сурвалжийг харах

Refactor: 修正变量名和一些注释

dk 1 жил өмнө
parent
commit
015ced6e5e

+ 12 - 9
backend/dataloaders/neo.py

@@ -17,13 +17,16 @@ FINGERMODEL_IDS = {
 }
 
 
-def raw_preprocessing(data_root, session_paths:dict, epoch_time=1., epoch_length=5, rename_event=True):
+def raw_preprocessing(data_root, session_paths:dict, 
+                      upsampled_epoch_length=1., 
+                      ori_epoch_length=5, 
+                      rename_event=True):
     """
     Params:
         subj_root: 
-        session_names: dict of lists
-        epoch_time: 
-        epoch_length (int or 'varied')
+        session_paths: dict of lists
+        upsampled_epoch_length: 
+        ori_epoch_length (int or 'varied'): original epoch length in second
         rename_event (True, use unified event label, False use original)
     """
     raws_loaded = load_sessions(data_root, session_paths)
@@ -39,19 +42,19 @@ def raw_preprocessing(data_root, session_paths:dict, epoch_time=1., epoch_length
             mov_trial_ind = [finger_model]
             rest_trial_ind = [0]
         
-        if isinstance(epoch_length, int):
-            trial_duration = epoch_length
-        elif epoch_length == 'varied':
+        if isinstance(ori_epoch_length, int) or isinstance(ori_epoch_length, float):
+            trial_duration = ori_epoch_length
+        elif ori_epoch_length == 'varied':
             trial_duration = None
         else:
-            raise ValueError(f'Unsupported epoch_length {epoch_length}')
+            raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
         events = reconstruct_events(events, fs, finger_model, 
                                     mov_trial_ind=mov_trial_ind,
                                     rest_trial_ind=rest_trial_ind,
                                     trial_duration=trial_duration, 
                                     use_original_label=not rename_event)
         
-        events_upsampled = upsample_events(events, int(fs * epoch_time))
+        events_upsampled = upsample_events(events, int(fs * upsampled_epoch_length))
         annotations = mne.annotations_from_events(events_upsampled, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
         raw.set_annotations(annotations)
         raws.append(raw)