test_validation.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import unittest
  2. import os
  3. import numpy as np
  4. from glob import glob
  5. import shutil
  6. import mne
  7. from bci_core import utils as ana_utils
  8. from bci_core.online import model_loader
  9. from training import train_model
  10. from dataloaders import neo
  11. from online_sim import simulation, _construct_model_event
  12. from validation import val_by_epochs
  13. class TestOnlineSim(unittest.TestCase):
  14. @classmethod
  15. def setUpClass(cls):
  16. root_path = './tests/data'
  17. raw_train, cls.event_id = neo.raw_loader(root_path, {'flex': ['1']}, reref_method='bipolar')
  18. cls.raw_val, _ = neo.raw_loader(root_path, {'flex': ['2']},
  19. upsampled_epoch_length=None,
  20. reref_method='bipolar')
  21. # train with the first half
  22. model = train_model(raw_train, event_id=cls.event_id, model_type='baseline')
  23. ana_utils.model_saver(model, './tests/data/', 'baseline', 'test', cls.event_id)
  24. cls.model_path = glob(os.path.join('./tests/data/', 'test', '*.pkl'))[0]
  25. @classmethod
  26. def tearDownClass(cls) -> None:
  27. shutil.rmtree(os.path.join('./tests/data/', 'test'))
  28. return super().tearDownClass()
  29. def test_event_metric(self):
  30. event_gt = np.array([[0, 0, 0], [5, 0, 1], [7, 0, 0], [9, 0, 2]])
  31. event_pred = np.array([[1, 0, 0], [4, 0, 1], [6, 0, 1], [7, 0, 0], [10, 0, 1], [11, 0, 2]])
  32. fs = 1
  33. precision, recall, f1_score = ana_utils.event_metric(event_gt, event_pred, fs, ignore_event=(0,))
  34. self.assertEqual(f1_score, 2 / 3)
  35. self.assertEqual(precision, 1 / 2)
  36. self.assertEqual(recall, 1)
  37. def test_construct_event(self):
  38. seq_1 = [(1, -1), (2, -1), (3, -1), (4, 1)]
  39. seq_2 = [(1, 0), (2, 0), (4, 1)]
  40. gt = [[1, 0, 0], [4, 0, 1]]
  41. ret_ = _construct_model_event(seq_1, 1, start_cond=0)
  42. self.assertTrue(np.allclose(gt, ret_))
  43. ret_ = _construct_model_event(seq_2, 1, start_cond=0)
  44. self.assertTrue(np.allclose(gt, ret_))
  45. def test_sim(self):
  46. model = model_loader(self.model_path,
  47. state_change_threshold=0.7,
  48. state_trans_prob=0.7)
  49. metric_hmm, metric_nohmm, figs = simulation(self.raw_val, self.event_id, model=model, epoch_length=1., step_length=0.1)
  50. figs[0].savefig('./tests/data/pred_hmm.pdf')
  51. figs[1].savefig('./tests/data/pred_naive.pdf')
  52. self.assertTrue(metric_hmm[-2] > 0.7) # f1-score (with hmm)
  53. self.assertTrue(metric_nohmm[-2] < 0.4) # f1-score (without hmm)
  54. def test_val_model(self):
  55. metrices, fig_conf = val_by_epochs(self.raw_val, self.model_path, self.event_id, 1.)
  56. fig_conf.savefig('./tests/data/conf.pdf')
  57. self.assertGreater(metrices[0], 0.85)
  58. self.assertGreater(metrices[1], 0.7)
  59. self.assertGreater(metrices[2], 0.7)
  60. if __name__ == '__main__':
  61. unittest.main()