12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- import unittest
- from dataloaders import neo
- import mne
- import numpy as np
class TestDataloader(unittest.TestCase):
    """Tests for the ``neo`` session loaders and the trial-event parser."""

    # Path is relative to the current working directory, so these tests
    # expect to be run from the repository root.
    DATA_ROOT = './tests/data'

    @staticmethod
    def _assert_events_equal(actual, expected):
        """Assert two event arrays have the same shape and (near-)equal values.

        ``np.allclose`` broadcasts its arguments, so a result with the wrong
        number of rows/entries could still compare as "close".
        ``np.testing.assert_allclose`` rejects shape mismatches; the
        tolerances below are ``np.allclose``'s defaults, which the checks
        previously relied on.
        """
        np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-8)

    def test_load_sample_data(self):
        """Session 'flex'/'1' yields exactly two event codes, 75 events each."""
        sessions = {'flex': ['1']}
        raw, event_id = neo.raw_loader(self.DATA_ROOT, sessions, upsampled_epoch_length=1.)
        events, _ = mne.events_from_annotations(raw, event_id=event_id)
        # Count how often each event code (last column) occurs.
        _, counts = np.unique(events[:, -1], return_counts=True)
        # assert_array_equal also checks the number of codes; np.allclose
        # would broadcast a single count of 75 against (75, 75) and pass.
        np.testing.assert_array_equal(counts, [75, 75])

    def test_load_session(self):
        """Loaded sessions come back interleaved across tasks, not grouped."""
        sessions = {'flex': ['1', '3'], 'ball': ['2']}
        raws = neo.load_sessions(self.DATA_ROOT, sessions)
        self.assertEqual(len(raws), 3)
        # The first item of each entry is its task name; the expected order
        # alternates ('flex', 'ball', 'flex') instead of following the
        # grouping in `sessions` ('flex', 'flex', 'ball').
        tasks = tuple(task for task, _ in raws)
        self.assertTupleEqual(tasks, ('flex', 'ball', 'flex'))

    def test_event_parser(self):
        """``reconstruct_events`` collapses repeated markers into trials.

        The input rows are ``[sample, 0, code]`` sampled at ``fs = 100``; in
        every expected output the middle column holds the trial length in
        samples.
        """
        fs = 100
        test_event = np.array([[0, 0, 4], [100, 0, 4], [600, 0, 3],
                               [700, 0, 3], [1000, 0, 4], [1100, 0, 4]])

        # Fixed length: every trial lasts 4 s, i.e. 400 samples.
        gt = np.array([[0, 400, 4], [600, 400, 3], [1000, 400, 4]])
        ret = neo.reconstruct_events(test_event, fs, trial_duration=4)
        self._assert_events_equal(ret, gt)

        # Duration given per event code, in seconds.
        gt = np.array([[0, 400, 4], [600, 200, 3], [1000, 400, 4]])
        trial_duration = {4: 4., 3: 2.}
        ret = neo.reconstruct_events(test_event, fs, trial_duration=trial_duration)
        self._assert_events_equal(ret, gt)

        # Varying length (trial_duration=None): the expected durations are the
        # gap to the next trial's onset (600, 400) and, for the last trial,
        # the span of its own markers (1000 -> 1100 = 100).
        gt = np.array([[0, 600, 4], [600, 400, 3], [1000, 100, 4]])
        ret = neo.reconstruct_events(test_event, fs, trial_duration=None)
        self._assert_events_equal(ret, gt)

        # Use original events: every marker row is kept (no merging), each
        # with the fixed 400-sample duration. The input array is re-created
        # (identical values) before this call.
        test_event = np.array([[0, 0, 4], [100, 0, 4], [600, 0, 3],
                               [700, 0, 3], [1000, 0, 4], [1100, 0, 4]])
        gt = np.array([[0, 400, 4], [100, 400, 4], [600, 400, 3],
                       [700, 400, 3], [1000, 400, 4], [1100, 400, 4]])
        ret = neo.reconstruct_events(test_event, fs, trial_duration=4, use_ori_events=True)
        self._assert_events_equal(ret, gt)
|