1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- import unittest
- import os
- import numpy as np
- from glob import glob
- import shutil
- from bci_core import utils as ana_utils
- from training import train_model, model_saver
- from dataloaders import library_ieeg
- from validation import validation
class TestValidation(unittest.TestCase):
    """Tests for the event metric and for end-to-end model validation.

    ``setUpClass`` loads one ECoG recording from ``./tests/data``, keeps 8
    channels (selected by index), trains a ``'baseline'`` model on the first
    half of the recording and saves it under ``./tests/data/test``.
    ``tearDownClass`` removes that model directory again.

    All paths are relative to the current working directory, so the suite is
    expected to be run from the repository root.
    """

    # Fixture locations. Kept cwd-relative (and with the trailing slash) to
    # match the paths the original test passed to the project helpers.
    DATA_DIR = './tests/data/'
    MODEL_SUBDIR = 'test'

    @classmethod
    def setUpClass(cls):
        cls.event_id = {'ball': 2, 'rest': 0}
        raw = library_ieeg.raw_preprocessing(
            os.path.join(cls.DATA_DIR, 'ecog-data/1', 'bp_mot_t_h.mat'),
            finger_model='ball',
        )
        # Restrict to a fixed subset of channels (by index) to keep the test small.
        channel_indices = [5, 6, 7, 12, 13, 14, 20, 21]
        raw = raw.pick_channels([raw.info['ch_names'][i] for i in channel_indices])
        cls.raw = raw

        # Split the recording into two halves at the middle sample.
        t_min, t_max = raw.times[0], raw.times[-1]
        t_mid = raw.times[len(raw.times) // 2]
        raw_train = raw.copy().crop(tmin=t_min, tmax=t_mid, include_tmax=True)
        cls.raw_val = raw.copy().crop(tmin=t_mid, tmax=t_max)
        # Shift annotation onsets so they are relative to the start of the
        # second half. The length guard avoids an IndexError on a recording
        # without annotations.
        onsets = cls.raw_val.annotations.onset
        if len(onsets) and onsets[0] > t_mid:
            cls.raw_val.annotations.onset -= t_mid

        # Train on the first half and persist the model.
        model_dir = os.path.join(cls.DATA_DIR, cls.MODEL_SUBDIR)
        # Clear leftovers from an earlier aborted run (tearDownClass is not
        # called when setUpClass fails) so the glob below cannot pick up a
        # stale model file.
        shutil.rmtree(model_dir, ignore_errors=True)
        model = train_model(raw_train, event_id=cls.event_id, model_type='baseline')
        model_saver(model, cls.DATA_DIR, 'baseline', cls.MODEL_SUBDIR, cls.event_id)
        model_files = sorted(glob(os.path.join(model_dir, '*.pkl')))
        if not model_files:
            raise RuntimeError(f'model_saver wrote no .pkl file to {model_dir!r}')
        cls.model_path = model_files[0]

    @classmethod
    def tearDownClass(cls) -> None:
        # ignore_errors: a missing directory must not mask the actual test outcome.
        shutil.rmtree(os.path.join(cls.DATA_DIR, cls.MODEL_SUBDIR), ignore_errors=True)
        return super().tearDownClass()

    def test_event_metric(self):
        """Check precision/recall/F1 of ``event_metric`` on a hand-built example.

        Event arrays use rows of ``[sample, 0, event_code]``; code 0 is
        excluded via ``ignore_event``.
        """
        event_gt = np.array([[0, 0, 0], [5, 0, 1], [7, 0, 0], [9, 0, 2]])
        event_pred = np.array(
            [[1, 0, 0], [4, 0, 1], [6, 0, 1], [7, 0, 0], [10, 0, 1], [11, 0, 2]]
        )
        fs = 1
        precision, recall, f1_score = ana_utils.event_metric(
            event_gt, event_pred, fs, ignore_event=(0,)
        )
        # Float results: compare with a tolerance instead of exact equality.
        self.assertAlmostEqual(f1_score, 2 / 3)
        self.assertAlmostEqual(precision, 1 / 2)
        self.assertAlmostEqual(recall, 1)

    def test_validation(self):
        """Validate the saved model and check F1 and ``r`` exceed fixed thresholds.

        The two returned figures are written to ``erds.pdf`` / ``pred.pdf`` in
        the data directory (they are not removed by ``tearDownClass``).
        """
        # NOTE(review): the model is trained on the first half only, but it is
        # validated on the full recording (self.raw), which contains the
        # training half. self.raw_val is prepared in setUpClass yet unused.
        # Confirm whether self.raw_val was intended here; the thresholds below
        # may need retuning if it is switched.
        (precision, recall, f1_score, r), fig_erds, fig_pred = validation(
            self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7
        )
        fig_erds.savefig(os.path.join(self.DATA_DIR, 'erds.pdf'))
        fig_pred.savefig(os.path.join(self.DATA_DIR, 'pred.pdf'))
        self.assertGreater(f1_score, 0.9)
        self.assertGreater(r, 0.5)
if "__main__" == __name__:
    # Running this file directly (rather than through a test runner)
    # discovers and runs the TestCase classes defined in this module.
    unittest.main(module="__main__")
|