# online.py
import joblib
import numpy as np
import random
import logging
from scipy import signal
import mne
from .feature_extractors import filterbank_extractor
from .utils import parse_model_type


class Controller:
    """Online control interface.

    At runtime the main entry point is the decision method; call reset_buffer
    after every pneumatic-hand feedback to skip the pneumatic hand's
    refractory period.

    Args:
        virtual_feedback_rate (float): float in [0, 1] controlling the
            proportion of trials that receive virtual (sham) feedback.
        model_path (string): path to the model file; None disables the real
            decision model.
        state_change_threshold (float): probability threshold the HMM posterior
            must exceed before a state change is reported.
    """
    def __init__(self,
                 virtual_feedback_rate=1.,
                 model_path=None,
                 state_change_threshold=0.6):
        if (model_path is None) or (model_path == 'None'):
            self.real_feedback_model = None
        else:
            self.model_type, _ = parse_model_type(model_path)
            if self.model_type == 'baseline':
                self.real_feedback_model = BaselineHMM(model_path, state_change_threshold=state_change_threshold)
            else:
                raise NotImplementedError
        self.virtual_feedback_rate = virtual_feedback_rate

    def step_decision(self, data, true_label=None):
        """Interface for grasp training: performs a single decision only, with
        no Markov (HMM) process involved. Incorrect virtual feedback is output
        as 10000.

        Args:
            data (mne.io.RawArray): data for one decision window
            true_label (None or int): ground-truth label used for virtual
                feedback during training

        Return:
            int: unified label (-1: keep, 0: rest, 1: cylinder, 2: ball, 3: flex, 4: double, 5: treble)
        """
        virtual_feedback = self.virtual_feedback(true_label)
        if virtual_feedback is not None:
            return virtual_feedback
        if self.real_feedback_model is not None:
            fs, data = self.real_feedback_model.parse_data(data)
            p = self.real_feedback_model.step_probability(fs, data)
            pred = np.argmax(p)
            real_decision = self.real_feedback_model.model.classes_[pred]
            return real_decision
        else:
            raise ValueError('Neither a decision model nor a true label was given')

    def decision(self, data, true_label=None):
        """Main decision method. The output logic is:

        If a decision model is loaded, it always computes one decision step on
        the data and updates its internal buffer, whether or not true_label is
        given (the result is not necessarily returned).
        If true_label is given (training mode), a random draw decides whether
        this trial uses virtual feedback. If it does, a second draw decides
        whether the virtual feedback is correct or incorrect; incorrect virtual
        feedback is labelled 10000. If it does not, the model decision is used.
        If true_label is None (test mode), the model decision is used directly.

        Model decision logic, based on the model's recorded last_state:
        if the current state equals last_state, output -1;
        if the current state differs from last_state, output the current state.

        Args:
            data (mne.io.RawArray): data for one decision window
            true_label (None or int): ground-truth label used for virtual
                feedback during training

        Return:
            int: unified label (-1: keep, 0: rest, 1: cylinder, 2: ball, 3: flex, 4: double, 5: treble)
        """
        if self.real_feedback_model is not None:
            real_decision = self.real_feedback_model.verterbi(data)
            # map to unified label
            if real_decision != -1:
                real_decision = self.real_feedback_model.model.classes_[real_decision]

        virtual_feedback = self.virtual_feedback(true_label)
        if virtual_feedback is not None:
            return virtual_feedback

        # true_label is None or this trial does not use virtual feedback;
        # if there is no real model either, raise ValueError
        if self.real_feedback_model is None:
            raise ValueError('Neither a decision model nor a true label was given')
        return real_decision

    def virtual_feedback(self, true_label=None):
        """Return a virtual (sham) feedback label for this trial, or None.

        With probability virtual_feedback_rate the trial uses virtual feedback:
        80% of those trials return true_label and 20% return 10000
        (deliberately wrong feedback).
        """
        if true_label is not None:
            p = random.random()
            if p < self.virtual_feedback_rate:  # virtual feedback (error rate 0.2)
                p_correct = random.random()
                if p_correct < 0.8:
                    return true_label
                else:
                    return 10000
        return None

    def reset_buffer(self):
        # call after every real feedback
        if self.real_feedback_model is not None:
            self.real_feedback_model.reset_state()


class HMMModel:
    def __init__(self, n_classes=2, state_trans_prob=0.6, state_change_threshold=0.7):
        self.n_classes = n_classes
        self._probability = np.ones(n_classes) / n_classes
        self._last_state = 0
        self.state_change_threshold = state_change_threshold
        # TODO: train with daily use data
        # build the state transition matrix; row i holds P(next state | current state i)
        self.state_trans_matrix = np.zeros((n_classes, n_classes))
        # each state persists with probability state_trans_prob
        np.fill_diagonal(self.state_trans_matrix, state_trans_prob)
        # rest (state 0) transitions to each non-rest state with equal probability
        self.state_trans_matrix[0, 1:] = (1 - state_trans_prob) / (n_classes - 1)
        # each non-rest state falls back to rest with the remaining probability
        self.state_trans_matrix[1:, 0] = 1 - state_trans_prob
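        # Illustrative example (not from the original code): with n_classes=3 and
        # state_trans_prob=0.6 the matrix is
        #     [[0.6, 0.2, 0.2],
        #      [0.4, 0.6, 0.0],
        #      [0.4, 0.0, 0.6]]
        # so each non-rest state either persists (0.6) or falls back to rest (0.4),
        # and direct transitions between non-rest states have probability 0.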

    def reset_state(self):
        self._last_state = 0
        self._probability = np.ones(self.n_classes) / self.n_classes

    def set_current_state(self, current_state):
        self._last_state = current_state
        self._probability = np.zeros(self.n_classes)
        self._probability[current_state] = 1

    def step_probability(self, fs, data):
        # do preprocessing here
        # common average reference
        data -= data.mean(axis=0)
        return data

    def parse_data(self, data):
        # data arrives as a (fs, event, data_array) tuple; the event is unused here
        fs, event, data_array = data
        return fs, data_array

    def verterbi(self, data):
        """
        Interface for class decision
        """
        fs, data = self.parse_data(data)
        p = self.step_probability(fs, data)
        logging.debug('step probability: %s, current state probability: %s', p, self.probability)
        return self.update_state(p)

    def update_state(self, current_p):
        # single HMM filtering (Viterbi-style) step: propagate the previous state
        # distribution through the transition matrix, weight by the current
        # observation probabilities, then renormalize
        self._probability = (self.state_trans_matrix.T @ self._probability) * current_p
        # normalize
        self._probability /= np.sum(self._probability)
        current_state = np.argmax(self._probability)
        if current_state == self._last_state:
            return -1
        else:
            # only switch states when the posterior clears the threshold
            if self._probability[current_state] > self.state_change_threshold:
                self.set_current_state(current_state)
                return current_state
            else:
                return -1
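    # Illustrative walk-through of update_state (not from the original code):
    # with n_classes=2, state_trans_prob=0.6, state_change_threshold=0.7, a
    # uniform prior [0.5, 0.5] and observation probabilities [0.2, 0.8], the
    # propagated prior is again [0.5, 0.5]; weighting and normalizing gives
    # [0.2, 0.8], so argmax is state 1 and 0.8 > 0.7: the state switches from
    # 0 to 1 and update_state returns 1.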

    @property
    def probability(self):
        return self._probability[self._last_state]


class BaselineHMM(HMMModel):
    def __init__(self, model, **kwargs):
        if isinstance(model, str):
            self.model = joblib.load(model)
        else:
            self.model = model
        # filter-bank frequencies: 20 Hz to 140 Hz in steps of 15 Hz
        self.freqs = np.arange(20, 150, 15)
        super(BaselineHMM, self).__init__(n_classes=len(self.model.classes_), **kwargs)

    def step_probability(self, fs, data):
        """Compute class probabilities for a single data window."""
        data = super(BaselineHMM, self).step_probability(fs, data)
        # filter data into filter-bank features
        filter_bank_data = filterbank_extractor(data, fs, self.freqs, reshape_freqs_dim=True)
        # downsample: two decimation passes, each by a factor of sqrt(fs / 10)
        decimate_rate = np.sqrt(fs / 10).astype(np.int16)
        filter_bank_data = signal.decimate(filter_bank_data, decimate_rate, axis=-1, zero_phase=True)
        filter_bank_data = signal.decimate(filter_bank_data, decimate_rate, axis=-1, zero_phase=True)
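        # Worked example (not from the original code): at fs = 1000 Hz,
        # decimate_rate = sqrt(1000 / 10) = 10, so the two passes above give a
        # combined 100x downsampling, i.e. an effective output rate of 10 Hz.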
        # predict proba
        p = self.model.predict_proba(filter_bank_data[None]).squeeze()
        return p


class RiemannHMM(HMMModel):
    def __init__(self, model, **kwargs):
        if isinstance(model, str):
            self.feat_extractor, self.scaler, self.cov, self.model = joblib.load(model)
        else:
            self.feat_extractor, self.scaler, self.cov, self.model = model
        super(RiemannHMM, self).__init__(n_classes=len(self.model.classes_), **kwargs)

    def step_probability(self, fs, data):
        """Compute class probabilities for a single data window."""
        data = super(RiemannHMM, self).step_probability(fs, data)
        # extract features
        data = self.feat_extractor.transform(data)
        # scale data
        data = self.scaler.transform(data)
        # compute covariance
        data = self.cov.transform(data)
        # predict proba
        p = self.model.predict_proba(data[None]).squeeze()
        return p
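

# Usage sketch (illustrative, not part of the original module). The relative
# imports above mean this file only runs as part of its package (e.g. via
# `python -m <package>.online`); the package name is not shown here. With
# model_path=None only the virtual (sham) feedback path is exercised; passing a
# real model path would enable the HMM decision path instead.
if __name__ == '__main__':
    controller = Controller(virtual_feedback_rate=1., model_path=None)
    # Training mode: true_label is known, so virtual feedback is drawn; the result
    # is either the true label (80%) or 10000 (deliberately wrong feedback, 20%).
    # data is unused on this virtual-feedback-only path, so None is passed.
    print(controller.decision(data=None, true_label=1))
    # Call reset_buffer after every real (pneumatic hand) feedback to reset the HMM state.
    controller.reset_buffer()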