# test_online.py
import itertools
import os
import random
import shutil
import unittest
from glob import glob

import numpy as np

import bci_core.online as online
from bci_core.utils import model_saver
import training
from dataloaders import neo
from online_sim import DataGenerator
  12. class TestOnline(unittest.TestCase):
  13. @classmethod
  14. def setUpClass(cls):
  15. root_path = './tests/data'
  16. raw, event_id = neo.raw_loader(root_path, {'flex': ['1']}, reref_method='bipolar')
  17. model = training.train_model(raw, event_id, model_type='baseline')
  18. model_saver(model, root_path, 'baseline', 'f77cbe10a8de473992542e9f4e913a66', event_id)
  19. cls.model_root = os.path.join(root_path, 'f77cbe10a8de473992542e9f4e913a66')
  20. cls.model_path = glob(os.path.join(root_path, 'f77cbe10a8de473992542e9f4e913a66', '*.pkl'))[0]
  21. raw, event_id = neo.raw_loader(root_path, {'flex': ['1']}, reref_method='monopolar')
  22. cls.data_gen = DataGenerator(raw.info['sfreq'], raw.get_data())
  23. @classmethod
  24. def tearDownClass(cls) -> None:
  25. shutil.rmtree(cls.model_root)
  26. return super().tearDownClass()
  27. def test_step_feedback(self):
  28. model_hmm = online.model_loader(self.model_path)
  29. controller = online.Controller(0, model_hmm, reref_method='bipolar')
  30. rets = []
  31. for time, data in self.data_gen.loop():
  32. cls = controller.step_decision(data)
  33. rets.append(cls)
  34. self.assertTrue(np.allclose(np.unique(rets), [0, 3]))
  35. def test_virtual_feedback(self):
  36. controller = online.Controller(1, None)
  37. n_trial = 1000
  38. correct = 0
  39. for _ in range(n_trial):
  40. label = random.randint(0, 1)
  41. ret = controller.decision(None, label)
  42. if ret == label:
  43. correct += 1
  44. self.assertTrue(abs(correct / n_trial - 0.8) < 0.1)
  45. correct = 0
  46. for _ in range(n_trial):
  47. label = random.randint(0, 1)
  48. ret = controller.step_decision(None, label)
  49. if ret == label:
  50. correct += 1
  51. self.assertTrue(abs(correct / n_trial - 0.8) < 0.1)
  52. def test_real_feedback(self):
  53. model_hmm = online.model_loader(self.model_path)
  54. controller = online.Controller(0, model_hmm, reref_method='bipolar')
  55. rets = []
  56. for i, (time, data) in zip(range(300), self.data_gen.loop()):
  57. cls = controller.decision(data)
  58. rets.append(cls)
  59. self.assertTrue(np.allclose(np.unique(rets), [-1, 0, 3]))
  60. class TestHMM(unittest.TestCase):
  61. def test_state_transfer(self):
  62. # binary
  63. probs = [[0.9, 0.1], [0.5, 0.5], [0.09, 0.91], [0.5, 0.5], [0.3, 0.7], [0.7, 0.3], [0.92,0.08]]
  64. true_state = [-1, -1, 1, -1, -1, -1, 0]
  65. model = online.HMMModel(transmat=None, n_classes=2, state_trans_prob=0.9, state_change_threshold=0.7, momentum=0.)
  66. states = []
  67. for p in probs:
  68. cur_state = model.update_state(p)
  69. states.append(cur_state)
  70. self.assertTrue(np.allclose(states, true_state))
  71. # triple
  72. probs = [[0.8, 0.1, 0.1], [0.01, 0.91, 0.09], [0.01, 0.08, 0.91], [0.5, 0.2, 0.3], [0.9, 0.05, 0.02], [0.01, 0.01, 0.98]]
  73. true_state = [-1, 1, -1, -1, 0, 2]
  74. model = online.HMMModel(transmat=None, n_classes=3, state_trans_prob=0.9, momentum=0.)
  75. states = []
  76. for p in probs:
  77. cur_state = model.update_state(p)
  78. states.append(cur_state)
  79. self.assertTrue(np.allclose(states, true_state))
  80. def test_momentum(self):
  81. # binary
  82. probs = [[0.9, 0.1], [0.5, 0.5], [0.09, 0.91], [0.01, 0.99], [0.3, 0.7], [0.7, 0.3], [0.92,0.08]]
  83. true_state = [-1, -1, -1, -1, 1, -1, -1]
  84. model = online.HMMModel(transmat=None, n_classes=2, state_trans_prob=0.9, state_change_threshold=0.7, momentum=0.5)
  85. states = []
  86. for p in probs:
  87. cur_state = model.update_state(p)
  88. states.append(cur_state)
  89. print(states)
  90. self.assertTrue(np.allclose(states, true_state))