test_validation.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import unittest
  2. import os
  3. import numpy as np
  4. from glob import glob
  5. import shutil
  6. from bci_core import utils as ana_utils
  7. from training import train_model, model_saver
  8. from dataloaders import neo
  9. from online_sim import simulation
  10. from validation import val_by_epochs
  11. class TestOnlineSim(unittest.TestCase):
  12. @classmethod
  13. def setUpClass(cls):
  14. root_path = './tests/data'
  15. raw, cls.event_id = neo.raw_preprocessing(root_path, {'flex': ['1', '2']})
  16. cls.raw = raw
  17. # split into 2 pieces
  18. t_min, t_max = raw.times[0], raw.times[-1]
  19. t_mid = raw.times[len(raw.times) // 2]
  20. raw_train = raw.copy().crop(tmin=t_min, tmax=t_mid, include_tmax=True)
  21. cls.raw_val = raw.copy().crop(tmin=t_mid, tmax=t_max)
  22. # reconstruct single event for validation
  23. if cls.raw_val.annotations.onset[0] > t_mid:
  24. # correct time by first timestamp
  25. cls.raw_val.annotations.onset -= t_mid
  26. # train with the first half
  27. model = train_model(raw_train, event_id=cls.event_id, model_type='baseline')
  28. model_saver(model, './tests/data/', 'baseline', 'test', cls.event_id)
  29. cls.model_path = glob(os.path.join('./tests/data/', 'test', '*.pkl'))[0]
  30. @classmethod
  31. def tearDownClass(cls) -> None:
  32. shutil.rmtree(os.path.join('./tests/data/', 'test'))
  33. return super().tearDownClass()
  34. def test_event_metric(self):
  35. event_gt = np.array([[0, 0, 0], [5, 0, 1], [7, 0, 0], [9, 0, 2]])
  36. event_pred = np.array([[1, 0, 0], [4, 0, 1], [6, 0, 1], [7, 0, 0], [10, 0, 1], [11, 0, 2]])
  37. fs = 1
  38. precision, recall, f1_score = ana_utils.event_metric(event_gt, event_pred, fs, ignore_event=(0,))
  39. self.assertEqual(f1_score, 2 / 3)
  40. self.assertEqual(precision, 1 / 2)
  41. self.assertEqual(recall, 1)
  42. def test_sim(self):
  43. metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7)
  44. fig_pred.savefig('./tests/data/pred.pdf')
  45. self.assertTrue(metric_hmm[-2] > 0.3) # f1-score (with hmm)
  46. self.assertTrue(metric_nohmm[-2] < 0.15) # f1-score (without hmm)
  47. def test_val_model(self):
  48. metrices, fig_conf = val_by_epochs(self.raw_val, self.model_path, self.event_id, 1.)
  49. fig_conf.savefig('./tests/data/conf.pdf')
  50. self.assertGreater(metrices[0], 0.85)
  51. self.assertGreater(metrices[1], 0.7)
  52. self.assertGreater(metrices[2], 0.7)
  53. if __name__ == '__main__':
  54. unittest.main()