1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- import os
- import shutil
- import random
- import bci_core.online as online
- import training
- from dataloaders import neo
- from online_sim import DataGenerator
- import unittest
- import numpy as np
- from glob import glob
class TestOnline(unittest.TestCase):
    """Tests for ``online.Controller`` using a model trained once per class.

    ``setUpClass`` trains and saves a baseline model from the recording under
    ``./tests/data``; ``tearDownClass`` removes the saved model directory.
    """

    # Identifier handed to ``training.model_saver``. The tests look for the
    # saved ``*.pkl`` file in a directory of this name under the data root.
    MODEL_ID = 'f77cbe10a8de473992542e9f4e913a66'

    @classmethod
    def setUpClass(cls):
        """Train and save a baseline model, then build the data generator.

        Raises:
            FileNotFoundError: if no ``*.pkl`` file is found in the model
                directory after saving.
        """
        root_path = './tests/data'
        raw, event_id = neo.raw_loader(root_path, {'flex': ['1']})

        model = training.train_model(raw, event_id, model_type='baseline')
        training.model_saver(model, root_path, 'baseline', cls.MODEL_ID, event_id)

        cls.model_root = os.path.join(root_path, cls.MODEL_ID)
        saved_models = glob(os.path.join(cls.model_root, '*.pkl'))
        if not saved_models:
            # Fail with an explicit message instead of a bare IndexError.
            raise FileNotFoundError(
                'no saved model (*.pkl) found in ' + cls.model_root
            )
        cls.model_path = saved_models[0]
        cls.data_gen = DataGenerator(raw.info['sfreq'], raw.get_data())

    @classmethod
    def tearDownClass(cls) -> None:
        """Delete the model directory created in ``setUpClass``."""
        shutil.rmtree(cls.model_root)
        return super().tearDownClass()

    @staticmethod
    def _agreement_rate(decide, n_trial):
        """Return the fraction of trials in which ``decide`` echoes the label.

        Each trial draws a random label from {0, 1} and calls
        ``decide(None, label)``.
        """
        hits = 0
        for _ in range(n_trial):
            label = random.randint(0, 1)
            if decide(None, label) == label:
                hits += 1
        return hits / n_trial

    def test_step_feedback(self):
        """``step_decision`` over the whole recording yields only labels 0 and 3."""
        model_hmm = online.model_loader(self.model_path)
        controller = online.Controller(0, model_hmm)
        decisions = [
            controller.step_decision(data)
            for _, data in self.data_gen.loop()
        ]
        # Exact comparison: reports the offending labels and treats a
        # different number of distinct labels as a plain test failure.
        np.testing.assert_array_equal(np.unique(decisions), [0, 3])

    def test_virtual_feedback(self):
        """Without a model, decisions agree with the given label about 80% of the time."""
        controller = online.Controller(1, None)
        n_trial = 1000

        methods = (
            ('decision', controller.decision),
            ('step_decision', controller.step_decision),
        )
        for name, decide in methods:
            # subTest keeps a failure in one method from hiding the other.
            with self.subTest(method=name):
                rate = self._agreement_rate(decide, n_trial)
                self.assertLess(
                    abs(rate - 0.8), 0.1,
                    msg='agreement rate was {}'.format(rate),
                )

    def test_real_feedback(self):
        """``decision`` over the first 300 chunks yields only labels -1, 0 and 3."""
        model_hmm = online.model_loader(self.model_path)
        controller = online.Controller(0, model_hmm)
        n_steps = 300
        # zip with range() caps the number of chunks pulled from the generator.
        decisions = [
            controller.decision(data)
            for _, (_, data) in zip(range(n_steps), self.data_gen.loop())
        ]
        np.testing.assert_array_equal(np.unique(decisions), [-1, 0, 3])
class TestHMM(unittest.TestCase):
    """Tests for the state sequence produced by ``online.HMMModel.update_state``."""

    def test_state_transfer(self):
        """Feed fixed probability vectors and compare the reported states.

        Each case builds a fresh ``HMMModel`` (``transmat=None``), calls
        ``update_state`` once per probability vector in order, and checks the
        collected return values against the expected sequence.
        """
        cases = [
            # binary
            (
                'binary',
                {'n_classes': 2, 'state_trans_prob': 0.9, 'state_change_threshold': 0.7},
                [[0.9, 0.1], [0.5, 0.5], [0.09, 0.91], [0.5, 0.5],
                 [0.3, 0.7], [0.7, 0.3], [0.92, 0.08]],
                [-1, -1, 1, -1, -1, -1, 0],
            ),
            # triple
            (
                'triple',
                {'n_classes': 3, 'state_trans_prob': 0.9},
                [[0.8, 0.1, 0.1], [0.01, 0.91, 0.09], [0.01, 0.08, 0.91],
                 [0.5, 0.2, 0.3], [0.9, 0.05, 0.02], [0.01, 0.01, 0.98]],
                [-1, 1, -1, -1, 0, 2],
            ),
        ]
        for name, model_kwargs, probs, true_state in cases:
            # subTest reports each scenario separately, so a failure in the
            # binary case no longer hides the result of the triple case.
            with self.subTest(case=name):
                model = online.HMMModel(transmat=None, **model_kwargs)
                states = [model.update_state(p) for p in probs]
                # Shows the mismatching positions instead of a bare False.
                np.testing.assert_array_equal(states, true_state)
|