Przeglądaj źródła

修正:trial duration为字典时的错误

dk 1 rok temu
rodzic
commit
5d1d1e1881
2 zmienionych plików z 4 dodań i 1 usunięć
  1. 3 1
      backend/dataloaders/neo.py
  2. 1 0
      backend/tests/test_neoloader.py

+ 3 - 1
backend/dataloaders/neo.py

@@ -40,6 +40,8 @@ def raw_loader(data_root, session_paths:dict,
             trial_duration = ori_epoch_length
         elif ori_epoch_length == 'varied':
             trial_duration = None
+        elif isinstance(ori_epoch_length, dict):
+            trial_duration = ori_epoch_length
         else:
             raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
         events = reconstruct_events(events, fs, 
@@ -88,7 +90,7 @@ def reconstruct_events(events, fs, trial_duration=5):
         events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
     elif isinstance(trial_duration, dict):
         for e in trial_duration.keys():
-            events_new[events_new[:, 2] == e] = trial_duration[e]
+            events_new[events_new[:, 2] == e, 1] = int(trial_duration[e] * fs)
     else:
         events_new[:, 1] = int(trial_duration * fs)
     return events_new

+ 1 - 0
backend/tests/test_neoloader.py

@@ -35,6 +35,7 @@ class TestDataloader(unittest.TestCase):
         gt = np.array([[0, 400, 4], [600, 200, 3], [1000, 400, 4]])
         trial_duration = {4: 4., 3: 2.}
         ret = neo.reconstruct_events(test_event, fs, trial_duration=trial_duration)
+        self.assertTrue(np.allclose(ret, gt))
         # varing length
         gt = np.array([[0, 600, 4], [600, 400, 3], [1000, 100, 4]])
         ret = neo.reconstruct_events(test_event, fs, trial_duration=None)