# test_neoloader.py (2.1 KB)

import unittest

import mne
import numpy as np

from dataloaders import neo
  5. class TestDataloader(unittest.TestCase):
  6. def test_load_sample_data(self):
  7. root_path = './tests/data'
  8. sessions = {'flex': ['1']}
  9. raw, event_id = neo.raw_loader(root_path, sessions, upsampled_epoch_length=1.)
  10. events, event_id = mne.events_from_annotations(raw, event_id=event_id)
  11. events, events_cnt = np.unique(events[:, -1], return_counts=True)
  12. self.assertTrue(np.allclose(events_cnt, (75, 75)))
  13. def test_load_session(self):
  14. root_path = './tests/data'
  15. sessions = {'flex': ['1', '3'], 'ball': ['2']}
  16. raws = neo.load_sessions(root_path, sessions)
  17. # test if interleaved
  18. sess_f = tuple(f for f, r in raws)
  19. self.assertEqual(len(raws), 3)
  20. self.assertTupleEqual(sess_f, ('flex', 'ball', 'flex'))
  21. def test_event_parser(self):
  22. # fixed length
  23. fs = 100
  24. test_event = np.array([[0, 0, 4], [100, 0, 4], [600, 0, 3], [700, 0, 3], [1000, 0, 4], [1100, 0, 4]])
  25. gt = np.array([[0, 400, 4], [600, 400, 3], [1000, 400, 4]])
  26. ret = neo.reconstruct_events(test_event, fs, trial_duration=4)
  27. self.assertTrue(np.allclose(ret, gt))
  28. # duration as dict
  29. gt = np.array([[0, 400, 4], [600, 200, 3], [1000, 400, 4]])
  30. trial_duration = {4: 4., 3: 2.}
  31. ret = neo.reconstruct_events(test_event, fs, trial_duration=trial_duration)
  32. self.assertTrue(np.allclose(ret, gt))
  33. # varing length
  34. gt = np.array([[0, 600, 4], [600, 400, 3], [1000, 100, 4]])
  35. ret = neo.reconstruct_events(test_event, fs, trial_duration=None)
  36. self.assertTrue(np.allclose(ret, gt))
  37. # use ori
  38. test_event = np.array([[0, 0, 4], [100, 0, 4], [600, 0, 3], [700, 0, 3], [1000, 0, 4], [1100, 0, 4]])
  39. gt = np.array([[0, 400, 4], [100, 400, 4], [600, 400, 3], [700, 400, 3], [1000, 400, 4], [1100, 400, 4]])
  40. ret = neo.reconstruct_events(test_event, fs, trial_duration=4, use_ori_events=True)
  41. self.assertTrue(np.allclose(ret, gt))