online.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. import joblib
  2. import numpy as np
  3. import random
  4. import logging
  5. import os
  6. from scipy import signal
  7. from .utils import parse_model_type, reref
  8. from .pipeline import data_evaluation
  9. logger = logging.getLogger(__name__)
  10. class Controller:
  11. """在线控制接口
  12. 运行时主要调用decision方法,
  13. 每次气动手反馈后调用reset_buffer方法,用以跳过气动手不应期
  14. Args:
  15. virtual_feedback_rate (float): 0-1之间浮点数,控制假反馈占比
  16. model_path (string): 模型文件路径
  17. buffer_steps (int):
  18. """
  19. def __init__(self,
  20. virtual_feedback_rate=1.,
  21. real_feedback_model=None,
  22. reref_method='monopolar'):
  23. self.real_feedback_model = real_feedback_model
  24. self.virtual_feedback_rate = virtual_feedback_rate
  25. self.reref_method = reref_method
  26. def step_decision(self, data, true_label=None):
  27. """抓握训练调用接口,只进行单次判决,不涉及马尔可夫过程,
  28. 假反馈的错误反馈默认输出为10000
  29. Args:
  30. data (mne.io.RawArray): 数据
  31. true_label (None or int): 训练时假反馈的真实标签
  32. Return:
  33. int: 统一化标签 (-1: keep, 0: rest, 1: cylinder, 2: ball, 3: flex, 4: double, 5: treble)
  34. """
  35. virtual_feedback = self.virtual_feedback(true_label)
  36. logger.debug('step_decision: virtual feedback: {}'.format(virtual_feedback))
  37. if virtual_feedback is not None:
  38. return virtual_feedback
  39. if self.real_feedback_model is not None:
  40. fs, data = self.parse_data(data)
  41. p = self.real_feedback_model.step_probability(fs, data)
  42. logger.debug('step_decison: model probability: {}'.format(str(p)))
  43. pred = np.argmax(p)
  44. real_decision = self.real_feedback_model.model.classes_[pred]
  45. return real_decision
  46. else:
  47. raise ValueError('Neither decision model nor true label are given')
  48. def decision(self, data, true_label=None):
  49. """决策主要方法,输出逻辑如下:
  50. 如果有决策模型,无论是否有true_label,都会使用模型进行一步决策计算并填入buffer(不一定返回)
  51. 如果有true_label(训练模式),产生一个随机数确定本trial是否为假反馈,
  52. 是假反馈,产生一个随机数确定本trial产生正确or错误的假反馈,假反馈的标签为10000
  53. 不是假反馈,使用模型决策
  54. 如果没有true_label(测试模式),直接使用模型决策
  55. 模型决策逻辑:
  56. 根据模型记录的last_state,
  57. 如果当前state和last_state相同,输出-1
  58. 如果当前state和last_state不同,输出当前state
  59. Args:
  60. data (mne.io.RawArray): 数据
  61. true_label (None or int): 训练时假反馈的真实标签
  62. Return:
  63. int: 统一化标签 (-1: keep, 0: rest, 1: cylinder, 2: ball, 3: flex, 4: double, 5: treble)
  64. """
  65. if self.real_feedback_model is not None:
  66. fs, data = self.parse_data(data)
  67. real_decision = self.real_feedback_model.viterbi(fs, data)
  68. # map to unified label
  69. if real_decision != -1:
  70. real_decision = self.real_feedback_model.model.classes_[real_decision]
  71. virtual_feedback = self.virtual_feedback(true_label)
  72. if virtual_feedback is not None:
  73. return virtual_feedback
  74. # true_label is None or not running virtual feedback in this trial
  75. # if no real model, raise ValueError
  76. if self.real_feedback_model is None:
  77. raise ValueError('Neither decision model nor true label are given')
  78. return real_decision
  79. def virtual_feedback(self, true_label=None):
  80. if true_label is not None:
  81. p = random.random()
  82. if p < self.virtual_feedback_rate: # virtual feedback (error rate 0.2)
  83. p_correct = random.random()
  84. if p_correct < 0.8:
  85. return true_label
  86. else:
  87. return 10000
  88. return None
  89. def parse_data(self, data):
  90. fs, event, data_array = data
  91. # do preprocessing
  92. data_array = reref(data_array, self.reref_method)
  93. return fs, data_array
  94. class HMMModel:
  95. """HMMModel 是一个基于隐马尔可夫模型(Hidden Markov Model, HMM)的框架,用于建模状态转移和更新。"""
  96. def __init__(self,
  97. transmat=None,
  98. n_classes=2,
  99. state_trans_prob=0.6,
  100. state_change_threshold=0.5,
  101. momentum=0.5):
  102. """
  103. 初始化HMM模型。
  104. transmat: 状态转移矩阵,如果为 None,则自动生成一个简单的转移矩阵。
  105. n_classes: 状态的数量。
  106. state_trans_prob: 状态保持不变的概率。
  107. state_change_threshold: 状态改变的阈值。
  108. momentum: 用于更新状态概率的动量因子。
  109. """
  110. self.n_classes = n_classes
  111. self.set_current_state(0)
  112. self.state_change_threshold = state_change_threshold
  113. self.hold_state = False
  114. if transmat is None:
  115. # build state transition matrix
  116. self.state_trans_matrix = np.zeros((n_classes, n_classes))
  117. # fill diagonal
  118. np.fill_diagonal(self.state_trans_matrix, state_trans_prob)
  119. # fill 0 -> each state,
  120. self.state_trans_matrix[0, 1:] = (1 - state_trans_prob) / (n_classes - 1)
  121. self.state_trans_matrix[1:, 0] = 1 - state_trans_prob
  122. else:
  123. if isinstance(transmat, str):
  124. transmat = np.loadtxt(transmat)
  125. self.state_trans_matrix = transmat
  126. # momentum factor
  127. self.momentum = momentum
  128. def set_current_state(self, current_state):
  129. self._last_state = current_state
  130. self._probability = np.zeros(self.n_classes)
  131. self._probability[current_state] = 1.
  132. def step_probability(self, fs, data):
  133. raise NotImplementedError
  134. def offset_updater(self, decision):
  135. if decision == 0:
  136. if self.hold_state:
  137. self.hold_state = False
  138. return 0
  139. else:
  140. self.hold_state = True
  141. return -1
  142. else:
  143. return decision
  144. def viterbi(self, fs, data, return_step_p=False):
  145. """
  146. Interface for class decision
  147. """
  148. p = self.step_probability(fs, data)
  149. if return_step_p:
  150. return p, self.update_state(p)
  151. else:
  152. return self.update_state(p)
  153. def update_state(self, current_p):
  154. # veterbi algorithm
  155. prob = (self.state_trans_matrix * self._probability.T).sum(axis=1) * current_p
  156. # normalize
  157. prob /= np.sum(prob)
  158. # momentum
  159. self._probability = self.momentum * self._probability + (1 - self.momentum) * prob
  160. logger.debug("viterbi probability, {}".format(str(self._probability)))
  161. current_state = np.argmax(self._probability)
  162. if current_state == self._last_state:
  163. return -1
  164. else:
  165. if self._probability[current_state] > self.state_change_threshold:
  166. self.set_current_state(current_state)
  167. return self.offset_updater(current_state)
  168. else:
  169. return -1
  170. @property
  171. def probability(self):
  172. return self._probability.copy()
  173. class ClfEmissionHMM(HMMModel):
  174. """
  175. ClfEmissionHMM 则是 HMMModel 的一个扩展,结合了分类模型的输出作为HMM的发射概率。
  176. """
  177. def __init__(self, model, **kwargs):
  178. """
  179. 初始化分类器发射的HMM模型。
  180. model: 包含特征提取器、嵌入器和分类模型的元组或模型文件路径。
  181. """
  182. if isinstance(model, str):
  183. model = joblib.load(model)
  184. self.feat_extractor, self.embedder, self.model = model
  185. super(ClfEmissionHMM, self).__init__(n_classes=len(self.model.classes_), **kwargs)
  186. def step_probability(self, fs, data):
  187. p = data_evaluation([self.feat_extractor, self.embedder, self.model], data, fs, None, None, False).squeeze()
  188. return p
  189. def model_loader(model_path, **kwargs):
  190. """
  191. 模型如果存在训练好的transmat,会直接load
  192. """
  193. model_root, model_filename = os.path.dirname(model_path), os.path.basename(model_path)
  194. model_name = model_filename.split('.')[0]
  195. transmat_path = os.path.join(model_root, model_name + '_transmat.txt')
  196. if os.path.isfile(transmat_path):
  197. transmat = np.loadtxt(transmat_path)
  198. else:
  199. transmat = None
  200. kwargs['transmat'] = transmat
  201. return ClfEmissionHMM(model_path, **kwargs)