# online.py
  1. import joblib
  2. import numpy as np
  3. import random
  4. import logging
  5. import os
  6. from scipy import signal
  7. from .utils import parse_model_type
  8. logger = logging.getLogger(__name__)
  9. class Controller:
  10. """在线控制接口
  11. 运行时主要调用decision方法,
  12. 每次气动手反馈后调用reset_buffer方法,用以跳过气动手不应期
  13. Args:
  14. virtual_feedback_rate (float): 0-1之间浮点数,控制假反馈占比
  15. model_path (string): 模型文件路径
  16. buffer_steps (int):
  17. """
  18. def __init__(self,
  19. virtual_feedback_rate=1.,
  20. real_feedback_model=None):
  21. self.real_feedback_model = real_feedback_model
  22. self.virtual_feedback_rate = virtual_feedback_rate
  23. def step_decision(self, data, true_label=None):
  24. """抓握训练调用接口,只进行单次判决,不涉及马尔可夫过程,
  25. 假反馈的错误反馈默认输出为10000
  26. Args:
  27. data (mne.io.RawArray): 数据
  28. true_label (None or int): 训练时假反馈的真实标签
  29. Return:
  30. int: 统一化标签 (-1: keep, 0: rest, 1: cylinder, 2: ball, 3: flex, 4: double, 5: treble)
  31. """
  32. virtual_feedback = self.virtual_feedback(true_label)
  33. logger.debug('step_decision: virtual feedback: {}'.format(virtual_feedback))
  34. if virtual_feedback is not None:
  35. return virtual_feedback
  36. if self.real_feedback_model is not None:
  37. fs, data = self.real_feedback_model.parse_data(data)
  38. p = self.real_feedback_model.step_probability(fs, data)
  39. logger.debug('step_decison: model probability: {}'.format(str(p)))
  40. pred = np.argmax(p)
  41. real_decision = self.real_feedback_model.model.classes_[pred]
  42. return real_decision
  43. else:
  44. raise ValueError('Neither decision model nor true label are given')
  45. def decision(self, data, true_label=None):
  46. """决策主要方法,输出逻辑如下:
  47. 如果有决策模型,无论是否有true_label,都会使用模型进行一步决策计算并填入buffer(不一定返回)
  48. 如果有true_label(训练模式),产生一个随机数确定本trial是否为假反馈,
  49. 是假反馈,产生一个随机数确定本trial产生正确or错误的假反馈,假反馈的标签为10000
  50. 不是假反馈,使用模型决策
  51. 如果没有true_label(测试模式),直接使用模型决策
  52. 模型决策逻辑:
  53. 根据模型记录的last_state,
  54. 如果当前state和last_state相同,输出-1
  55. 如果当前state和last_state不同,输出当前state
  56. Args:
  57. data (mne.io.RawArray): 数据
  58. true_label (None or int): 训练时假反馈的真实标签
  59. Return:
  60. int: 统一化标签 (-1: keep, 0: rest, 1: cylinder, 2: ball, 3: flex, 4: double, 5: treble)
  61. """
  62. if self.real_feedback_model is not None:
  63. real_decision = self.real_feedback_model.viterbi(data)
  64. # map to unified label
  65. if real_decision != -1:
  66. real_decision = self.real_feedback_model.model.classes_[real_decision]
  67. virtual_feedback = self.virtual_feedback(true_label)
  68. if virtual_feedback is not None:
  69. return virtual_feedback
  70. # true_label is None or not running virtual feedback in this trial
  71. # if no real model, raise ValueError
  72. if self.real_feedback_model is None:
  73. raise ValueError('Neither decision model nor true label are given')
  74. return real_decision
  75. def virtual_feedback(self, true_label=None):
  76. if true_label is not None:
  77. p = random.random()
  78. if p < self.virtual_feedback_rate: # virtual feedback (error rate 0.2)
  79. p_correct = random.random()
  80. if p_correct < 0.8:
  81. return true_label
  82. else:
  83. return 10000
  84. return None
  85. class HMMModel:
  86. def __init__(self, transmat=None, n_classes=2, state_trans_prob=0.6, state_change_threshold=0.5):
  87. self.n_classes = n_classes
  88. self._probability = np.zeros(n_classes)
  89. self.reset_state()
  90. self.state_change_threshold = state_change_threshold
  91. if transmat is None:
  92. # build state transition matrix
  93. self.state_trans_matrix = np.zeros((n_classes, n_classes))
  94. # fill diagonal
  95. np.fill_diagonal(self.state_trans_matrix, state_trans_prob)
  96. # fill 0 -> each state,
  97. self.state_trans_matrix[0, 1:] = (1 - state_trans_prob) / (n_classes - 1)
  98. self.state_trans_matrix[1:, 0] = 1 - state_trans_prob
  99. else:
  100. if isinstance(transmat, str):
  101. transmat = np.loadtxt(transmat)
  102. self.state_trans_matrix = transmat
  103. # emission probability moving average, (5 steps)
  104. self._filter_b = np.ones(5) / 5
  105. self._z = np.zeros((len(self._filter_b) - 1, n_classes))
  106. def reset_state(self):
  107. self._probability[0] = 1.
  108. self._last_state = 0
  109. def set_current_state(self, current_state):
  110. self._last_state = current_state
  111. self._probability = np.zeros(self.n_classes)
  112. self._probability[current_state] = 1
  113. def step_probability(self, fs, data):
  114. # do preprocessing here
  115. # common average
  116. data -= data.mean(axis=0)
  117. return data
  118. def parse_data(self, data):
  119. fs, event, data_array = data
  120. return fs, data_array
  121. def filter_prob(self, probs):
  122. """
  123. Args:
  124. probs (np.ndarray): (n_classes,)
  125. Returns:
  126. filtered_probs (np.ndarray): (n_classes,)
  127. """
  128. filtered_probs, self._z = signal.lfilter(self._filter_b, 1, probs[None], axis=0, zi=self._z)
  129. return filtered_probs.squeeze()
  130. def viterbi(self, data, return_step_p=False):
  131. """
  132. Interface for class decision
  133. """
  134. fs, data = self.parse_data(data)
  135. p = self.step_probability(fs, data)
  136. # smooth p
  137. p = self.filter_prob(p)
  138. if return_step_p:
  139. return p, self.update_state(p)
  140. else:
  141. return self.update_state(p)
  142. def update_state(self, current_p):
  143. # veterbi algorithm
  144. self._probability = (self.state_trans_matrix * self._probability.T).sum(axis=1) * current_p
  145. # normalize
  146. self._probability /= np.sum(self._probability)
  147. logger.debug("viterbi probability, {}".format(str(self._probability)))
  148. current_state = np.argmax(self._probability)
  149. if current_state == self._last_state:
  150. return -1
  151. else:
  152. if self._probability[current_state] > self.state_change_threshold:
  153. self.set_current_state(current_state)
  154. return current_state
  155. else:
  156. return -1
  157. @property
  158. def probability(self):
  159. return self._probability.copy()
  160. class BaselineHMM(HMMModel):
  161. def __init__(self, model, **kwargs):
  162. if isinstance(model, str):
  163. model = joblib.load(model)
  164. self.feat_extractor, self.model = model
  165. super(BaselineHMM, self).__init__(n_classes=len(self.model.classes_), **kwargs)
  166. def step_probability(self, fs, data):
  167. """Step
  168. """
  169. data = super(BaselineHMM, self).step_probability(fs, data)
  170. # filter data
  171. filter_bank_data = self.feat_extractor.transform(data)
  172. # downsampling
  173. decimate_rate = np.sqrt(fs / 10).astype(np.int16)
  174. filter_bank_data = signal.decimate(filter_bank_data, decimate_rate, axis=-1, zero_phase=True)
  175. filter_bank_data = signal.decimate(filter_bank_data, decimate_rate, axis=-1, zero_phase=True)
  176. # predict proba
  177. p = self.model.predict_proba(filter_bank_data[None]).squeeze()
  178. return p
  179. class RiemannHMM(HMMModel):
  180. def __init__(self, model, **kwargs):
  181. if isinstance(model, str):
  182. model = joblib.load(model)
  183. self.feat_extractor, self.scaler, self.cov, self.model = model
  184. super(RiemannHMM, self).__init__(n_classes=len(self.model.classes_), **kwargs)
  185. def step_probability(self, fs, data):
  186. """Step
  187. """
  188. data = super(RiemannHMM, self).step_probability(fs, data)
  189. data = self.feat_extractor.transform(data)
  190. data = data[None] # pad trial dimension
  191. # scale data
  192. data = self.scaler.transform(data)
  193. # compute cov
  194. data = self.cov.transform(data)
  195. # predict proba
  196. p = self.model.predict_proba(data).squeeze()
  197. return p
  198. def model_loader(model_path, **kwargs):
  199. """
  200. 模型如果存在训练好的transmat,会直接load
  201. """
  202. model_root, model_filename = os.path.dirname(model_path), os.path.basename(model_path)
  203. model_name = model_filename.split('.')[0]
  204. transmat_path = os.path.join(model_root, model_name + '_transmat.txt')
  205. if os.path.isfile(transmat_path):
  206. transmat = np.loadtxt(transmat_path)
  207. else:
  208. transmat = None
  209. kwargs['transmat'] = transmat
  210. model_type, _ = parse_model_type(model_path)
  211. if model_type == 'baseline':
  212. return BaselineHMM(model_path, **kwargs)
  213. elif model_type == 'riemann':
  214. return RiemannHMM(model_path, **kwargs)
  215. else:
  216. raise ValueError(f'Unexpected model type: {model_type}, expect "baseline" or "riemann"')