12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- import os
- import training
- import unittest
- import joblib
- from glob import glob
- from sklearn.linear_model import LogisticRegression
- from dataloaders import neo
- import bci_core.utils as bci_utils
- from bci_core.feature_extractors import FeatExtractor
- from bci_core.pipeline import riemann_model_builder, baseline_model_builder
- import shutil
- from sklearn.utils.validation import check_is_fitted
- from sklearn.pipeline import Pipeline
- class TestTraining(unittest.TestCase):
- @classmethod
- def setUpClass(cls) -> None:
- root_path = './tests/data'
- sessions = {'flex': ['1', '2']}
- raw, cls.event_id = neo.raw_loader(root_path, sessions)
- cls.raw = raw
-
- def test_training_baseline(self):
- model = training.train_model(self.raw, self.event_id, model_type='baseline')
- check_is_fitted(model[-1])
-
- def test_training_riemann(self):
- model = training.train_model(self.raw, self.event_id, model_type='riemann')
- check_is_fitted(model[-1])
- def test_saver(self):
- feat_ext, embedder = riemann_model_builder(1000, 9, lf_bands=[(15, 30), [30, 45]], hg_bands=[(55, 95), (105, 145)])
- clf = LogisticRegression()
- event_id = {'1': 5, '0': 3}
- bci_utils.model_saver([feat_ext, embedder, clf], './tests/data', 'baseline', 'f77cbe10a8de473992542e9f4e913a66', event_id)
- self.assertTrue(os.path.isdir(os.path.join('./tests/data', 'f77cbe10a8de473992542e9f4e913a66')))
- model_file = glob(os.path.join('./tests/data',
- 'f77cbe10a8de473992542e9f4e913a66',
- '*.pkl'))
-
- self.assertEqual(len(model_file), 1)
- name = os.path.normpath(model_file[0]).split(os.sep)
- class_name, events, date = name[-1].split('_')
- print(class_name, events, date)
- self.assertTrue(class_name == 'baseline')
- self.assertTrue(events == '0+1')
- # load model
- feat, embedder, model_base = joblib.load(model_file[0])
- self.assertTrue(isinstance(feat, FeatExtractor))
- self.assertTrue(isinstance(embedder, Pipeline))
- self.assertTrue(isinstance(model_base, LogisticRegression))
- shutil.rmtree(os.path.join('./tests/data', 'f77cbe10a8de473992542e9f4e913a66'))
|