test_training.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import os
  2. import training
  3. import unittest
  4. import joblib
  5. from glob import glob
  6. from dataloaders import neo
  7. from bci_core.feature_extractors import FeatExtractor
  8. from bci_core.model import baseline_model, riemann_model, ChannelScaler
  9. import shutil
  10. from sklearn.utils.validation import check_is_fitted
  11. from sklearn.pipeline import Pipeline
  12. class TestTraining(unittest.TestCase):
  13. @classmethod
  14. def setUpClass(cls) -> None:
  15. root_path = './tests/data'
  16. sessions = {'flex': ['1', '2']}
  17. raw, cls.event_id = neo.raw_loader(root_path, sessions)
  18. cls.raw = raw
  19. def test_training_baseline(self):
  20. model = training.train_model(self.raw, self.event_id, model_type='baseline')
  21. check_is_fitted(model[1])
  22. def test_saver(self):
  23. feat_ext = FeatExtractor(1000, lfb_bands=[(15, 30), [30, 45]], hg_bands=[(55, 95), (105, 145)])
  24. model_riemann = riemann_model(1)
  25. model_baseline = baseline_model(1)
  26. scaler = ChannelScaler()
  27. event_id = {'1': 5, '0': 3}
  28. training.model_saver([feat_ext, scaler, model_riemann, model_baseline], './tests/data', 'baseline', 'f77cbe10a8de473992542e9f4e913a66', event_id)
  29. self.assertTrue(os.path.isdir(os.path.join('./tests/data', 'f77cbe10a8de473992542e9f4e913a66')))
  30. model_file = glob(os.path.join('./tests/data',
  31. 'f77cbe10a8de473992542e9f4e913a66',
  32. '*.pkl'))
  33. self.assertEqual(len(model_file), 1)
  34. name = os.path.normpath(model_file[0]).split(os.sep)
  35. class_name, events, date = name[-1].split('_')
  36. print(class_name, events, date)
  37. self.assertTrue(class_name == 'baseline')
  38. self.assertTrue(events == '0+1')
  39. # load model
  40. feat, scaler, model_riem, model_base = joblib.load(model_file[0])
  41. self.assertTrue(isinstance(feat, FeatExtractor))
  42. self.assertTrue(isinstance(scaler, ChannelScaler))
  43. self.assertTrue(isinstance(model_riem, Pipeline))
  44. self.assertTrue(isinstance(model_base, Pipeline))
  45. shutil.rmtree(os.path.join('./tests/data', 'f77cbe10a8de473992542e9f4e913a66'))