# test_neoloader.py (2.5 KB)

  1. import unittest
  2. from dataloaders import neo
  3. import mne
  4. import numpy as np
  5. class TestDataloader(unittest.TestCase):
  6. def test_load_sample_data(self):
  7. root_path = './tests/data'
  8. sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
  9. event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
  10. raw = neo.raw_preprocessing(root_path, sessions, unify_label=True)
  11. events, event_id = mne.events_from_annotations(raw, event_id=event_id)
  12. events, events_cnt = np.unique(events[:, -1], return_counts=True)
  13. self.assertTrue(np.allclose(events_cnt, (300, 150, 150)))
  14. def test_load_session(self):
  15. root_path = './tests/data'
  16. sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4']}
  17. raws = neo.load_sessions(root_path, sessions)
  18. # test if interleaved
  19. sess_f = tuple(f for f, r in raws)
  20. self.assertEqual(len(raws), 3)
  21. self.assertTupleEqual(sess_f, ('cylinder', 'ball', 'cylinder'))
  22. def test_event_parser(self):
  23. # fixed length
  24. fs = 100
  25. test_event = np.array([[0, 0, 4], [100, 0, 4], [600, 0, 3], [700, 0, 3], [1000, 0, 4], [1100, 0, 4]])
  26. gt = np.array([[0, 400, 0], [600, 400, 2], [1000, 400, 0]])
  27. ret = neo.reconstruct_events(test_event, fs, 'ball', trial_duration=4)
  28. self.assertTrue(np.allclose(ret, gt))
  29. # varing length
  30. gt = np.array([[0, 600, 0], [600, 400, 2], [1000, 100, 0]])
  31. ret = neo.reconstruct_events(test_event, fs, 'ball', trial_duration=None)
  32. self.assertTrue(np.allclose(ret, gt))
  33. # change indices
  34. gt = np.array([[0, 400, 2], [600, 400, 0], [1000, 400, 2]])
  35. ret = neo.reconstruct_events(test_event, fs, 'ball', mov_trial_ind=[4], rest_trial_ind=[2, 3], trial_duration=4)
  36. self.assertTrue(np.allclose(ret, gt))
  37. # use original indices
  38. gt = np.array([[0, 400, 4], [600, 400, 3], [1000, 400, 4]])
  39. ret = neo.reconstruct_events(test_event, fs, None, trial_duration=4, use_original_label=True, mov_trial_ind=[3], rest_trial_ind=[4])
  40. self.assertTrue(np.allclose(ret, gt))
  41. # use original indices,
  42. gt = np.array([[0, 400, 0], [600, 400, 1], [1000, 400, 0]])
  43. test_event = np.array([[0, 0, 0], [100, 0, 0], [600, 0, 1], [700, 0, 1], [1000, 0, 0], [1100, 0, 0]])
  44. ret = neo.reconstruct_events(test_event, fs, None, trial_duration=4, use_original_label=True, mov_trial_ind=[1], rest_trial_ind=[0])
  45. self.assertTrue(np.allclose(ret, gt))