test_online.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import os
  2. import shutil
  3. import random
  4. import bci_core.online as online
  5. import training
  6. from dataloaders import library_ieeg
  7. from validation import DataGenerator
  8. import unittest
  9. import numpy as np
  10. from glob import glob
  11. class TestOnline(unittest.TestCase):
  12. @classmethod
  13. def setUpClass(cls):
  14. root_path = './tests/data'
  15. event_id = {'ball': 2, 'rest': 0}
  16. raw = library_ieeg.raw_preprocessing(os.path.join(root_path, 'ecog-data/1', 'bp_mot_t_h.mat'), finger_model='ball')
  17. raw = raw.pick_channels([raw.info['ch_names'][i] for i in [5,6,7,12,13,14,20,21]])
  18. model = training.train_model(raw, event_id, model_type='baseline')
  19. training.model_saver(model, root_path, 'baseline', 'f77cbe10a8de473992542e9f4e913a66', event_id)
  20. cls.model_root = os.path.join(root_path, 'f77cbe10a8de473992542e9f4e913a66')
  21. cls.model_path = glob(os.path.join(root_path, 'f77cbe10a8de473992542e9f4e913a66', '*.pkl'))[0]
  22. cls.data_gen = DataGenerator(raw.info['sfreq'], raw.get_data())
  23. @classmethod
  24. def tearDownClass(cls) -> None:
  25. shutil.rmtree(cls.model_root)
  26. return super().tearDownClass()
  27. def test_step_feedback(self):
  28. controller = online.Controller(0, self.model_path)
  29. rets = []
  30. for time, data in self.data_gen.loop():
  31. cls = controller.step_decision(data)
  32. rets.append(cls)
  33. self.assertTrue(np.allclose(np.unique(rets), [0, 2]))
  34. def test_virtual_feedback(self):
  35. controller = online.Controller(1, None)
  36. n_trial = 1000
  37. correct = 0
  38. for _ in range(n_trial):
  39. label = random.randint(0, 1)
  40. ret = controller.decision(None, label)
  41. if ret == label:
  42. correct += 1
  43. self.assertTrue(abs(correct / n_trial - 0.8) < 0.1)
  44. correct = 0
  45. for _ in range(n_trial):
  46. label = random.randint(0, 1)
  47. ret = controller.step_decision(None, label)
  48. if ret == label:
  49. correct += 1
  50. self.assertTrue(abs(correct / n_trial - 0.8) < 0.1)
  51. def test_real_feedback(self):
  52. controller = online.Controller(0, self.model_path)
  53. rets = []
  54. for i, (time, data) in zip(range(300), self.data_gen.loop()):
  55. cls = controller.decision(data)
  56. rets.append(cls)
  57. self.assertTrue(np.allclose(np.unique(rets), [-1, 0, 2]))
  58. class TestHMM(unittest.TestCase):
  59. def test_state_transfer(self):
  60. # binary
  61. probs = [[0.9, 0.1], [0.5, 0.5], [0.09, 0.91], [0.5, 0.5], [0.3, 0.7], [0.7, 0.3], [0.92,0.08]]
  62. true_state = [-1, -1, 1, -1, -1, -1, 0]
  63. model = online.HMMModel(2, state_trans_prob=0.9, state_change_threshold=0.7)
  64. states = []
  65. for p in probs:
  66. cur_state = model.update_state(p)
  67. states.append(cur_state)
  68. self.assertTrue(np.allclose(states, true_state))
  69. # triple
  70. probs = [[0.8, 0.1, 0.1], [0.01, 0.91, 0.09], [0.01, 0.08, 0.91], [0.5, 0.2, 0.3], [0.9, 0.05, 0.02], [0.01, 0.01, 0.98]]
  71. true_state = [-1, 1, -1, -1, 0, 2]
  72. model = online.HMMModel(3, state_trans_prob=0.9)
  73. states = []
  74. for p in probs:
  75. cur_state = model.update_state(p)
  76. states.append(cur_state)
  77. self.assertTrue(np.allclose(states, true_state))