2
0

test_training.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import os
  2. import training
  3. import unittest
  4. import joblib
  5. from glob import glob
  6. from sklearn.linear_model import LogisticRegression
  7. from dataloaders import neo
  8. import bci_core.utils as bci_utils
  9. from bci_core.feature_extractors import FeatExtractor
  10. from bci_core.pipeline import riemann_model_builder, baseline_model_builder
  11. import shutil
  12. from sklearn.utils.validation import check_is_fitted
  13. from sklearn.pipeline import Pipeline
  14. class TestTraining(unittest.TestCase):
  15. @classmethod
  16. def setUpClass(cls) -> None:
  17. root_path = './tests/data'
  18. sessions = {'flex': ['1', '2']}
  19. raw, cls.event_id = neo.raw_loader(root_path, sessions)
  20. cls.raw = raw
  21. def test_training_baseline(self):
  22. model = training.train_model(self.raw, self.event_id, model_type='baseline')
  23. check_is_fitted(model[-1])
  24. def test_training_riemann(self):
  25. model = training.train_model(self.raw, self.event_id, model_type='riemann')
  26. check_is_fitted(model[-1])
  27. def test_saver(self):
  28. feat_ext, embedder = riemann_model_builder(1000, 9, lf_bands=[(15, 30), [30, 45]], hg_bands=[(55, 95), (105, 145)])
  29. clf = LogisticRegression()
  30. event_id = {'1': 5, '0': 3}
  31. bci_utils.model_saver([feat_ext, embedder, clf], './tests/data', 'baseline', 'f77cbe10a8de473992542e9f4e913a66', event_id)
  32. self.assertTrue(os.path.isdir(os.path.join('./tests/data', 'f77cbe10a8de473992542e9f4e913a66')))
  33. model_file = glob(os.path.join('./tests/data',
  34. 'f77cbe10a8de473992542e9f4e913a66',
  35. '*.pkl'))
  36. self.assertEqual(len(model_file), 1)
  37. name = os.path.normpath(model_file[0]).split(os.sep)
  38. class_name, events, date = name[-1].split('_')
  39. print(class_name, events, date)
  40. self.assertTrue(class_name == 'baseline')
  41. self.assertTrue(events == '0+1')
  42. # load model
  43. feat, embedder, model_base = joblib.load(model_file[0])
  44. self.assertTrue(isinstance(feat, FeatExtractor))
  45. self.assertTrue(isinstance(embedder, Pipeline))
  46. self.assertTrue(isinstance(model_base, LogisticRegression))
  47. shutil.rmtree(os.path.join('./tests/data', 'f77cbe10a8de473992542e9f4e913a66'))