import unittest
import os

import numpy as np

from bci_core import utils as ana_utils
from training import train_model
from dataloaders import library_ieeg
from validation import validation


class TestValidation(unittest.TestCase):
    """Tests for ``event_metric`` and the end-to-end ``validation`` pipeline.

    A 'baseline' model is trained once per class (in ``setUpClass``) on the
    first half of an ECoG recording stored under ``root_path`` and is reused
    by the tests.
    """

    # Directory holding the test recording; generated figures are written
    # here as well. Relative to the current working directory.
    root_path = './tests/data'

    @classmethod
    def setUpClass(cls):
        """Load the recording, split it in half and train a baseline model.

        Sets:
            cls.event_id: mapping of event names to integer codes.
            cls.raw: the preprocessed recording restricted to 8 channels.
            cls.raw_val: the second half of ``cls.raw`` (from the midpoint).
            cls.model: model returned by ``train_model`` on the first half.
        """
        cls.event_id = {'ball': 2, 'rest': 0}
        raw = library_ieeg.raw_preprocessing(
            os.path.join(cls.root_path, 'ecog-data/1', 'bp_mot_t_h.mat'),
            finger_model='ball',
        )
        # Keep a fixed subset of 8 channels, selected by index.
        channel_indices = [5, 6, 7, 12, 13, 14, 20, 21]
        raw = raw.pick_channels([raw.info['ch_names'][i] for i in channel_indices])
        cls.raw = raw

        # Split into 2 pieces at the middle sample.
        t_min, t_max = raw.times[0], raw.times[-1]
        t_mid = raw.times[len(raw.times) // 2]
        raw_train = raw.copy().crop(tmin=t_min, tmax=t_mid, include_tmax=True)
        cls.raw_val = raw.copy().crop(tmin=t_mid, tmax=t_max)

        # Reconstruct a single event for validation: if the cropped half's
        # annotation onsets still lie beyond t_mid, shift them back by t_mid.
        # NOTE(review): presumably this compensates for annotations staying
        # relative to the full recording after crop; confirm for the MNE
        # version in use.
        if cls.raw_val.annotations.onset[0] > t_mid:
            cls.raw_val.annotations.onset -= t_mid

        # Train with the first half only.
        cls.model = train_model(raw_train, event_id=cls.event_id, model_type='baseline')

    def test_event_metric(self):
        """Check precision/recall/F1 from ``event_metric`` on a hand-built case.

        Each row is ``[time, 0, event_code]`` (presumably the MNE events-array
        layout); ``ignore_event=(0,)`` is passed to exclude code 0. With
        ``fs = 1`` the first column is interpreted at a sampling rate of 1.
        """
        event_gt = np.array([[0, 0, 0],
                             [5, 0, 1],
                             [7, 0, 0],
                             [9, 0, 2]])
        event_pred = np.array([[1, 0, 0],
                               [4, 0, 1],
                               [6, 0, 1],
                               [7, 0, 0],
                               [10, 0, 1],
                               [11, 0, 2]])
        fs = 1
        precision, recall, f1_score = ana_utils.event_metric(
            event_gt, event_pred, fs, ignore_event=(0,))
        # Compare floats with a tolerance rather than exact equality, so the
        # test does not depend on the exact order of floating-point operations.
        self.assertAlmostEqual(f1_score, 2 / 3)
        self.assertAlmostEqual(precision, 1 / 2)
        self.assertAlmostEqual(recall, 1)

    def test_validation(self):
        """Run ``validation`` with the trained model and check score thresholds.

        Also saves the two returned figures as PDFs under ``root_path``.
        """
        # NOTE(review): this validates on ``self.raw`` (the full recording),
        # which includes the first half the model was trained on. The held-out
        # half ``self.raw_val`` is prepared in setUpClass but never used here;
        # confirm whether validation should run on ``self.raw_val`` instead
        # (the thresholds below may need re-tuning if so).
        (precision, recall, f1_score, r), fig_erds, fig_pred = validation(
            self.raw, 'baseline', self.event_id, model=self.model,
            state_change_threshold=0.7)
        fig_erds.savefig(os.path.join(self.root_path, 'erds.pdf'))
        fig_pred.savefig(os.path.join(self.root_path, 'pred.pdf'))
        # assertGreater reports the actual values on failure, unlike assertTrue.
        self.assertGreater(f1_score, 0.9)
        self.assertGreater(r, 0.5)


if __name__ == '__main__':
    unittest.main()