import random

import joblib
import numpy as np
from scipy import signal

from .feature_extractors import filterbank_extractor
from .utils import parse_model_type


class Controller:
    """Online control interface.

    At runtime the main entry point is the `decision` method; call
    `reset_buffer` after every pneumatic-hand feedback so that the hand's
    refractory period is skipped.

    Args:
        virtual_feedback_rate (float): value in [0, 1], the proportion of
            trials that receive virtual (sham) feedback.
        model_path (string): path to the model file.
        state_change_threshold (float): posterior probability a new state must
            exceed before a state change is reported.
    """
    def __init__(self, virtual_feedback_rate=1., model_path=None, state_change_threshold=0.6):
        if (model_path is None) or (model_path == 'None'):
            self.real_feedback_model = None
        else:
            self.model_type, _ = parse_model_type(model_path)
            if self.model_type == 'baseline':
                self.real_feedback_model = BaselineHMM(model_path, state_change_threshold=state_change_threshold)
            else:
                raise NotImplementedError

        self.virtual_feedback_rate = virtual_feedback_rate

    def set_real_feedback_model(self, model):
        self.real_feedback_model = model

    def step_decision(self, data, true_label=None):
        """Interface for grasp training: performs a single-step decision without
        the Markov process. Incorrect virtual feedback is returned as 10000.

        Args:
            data (mne.io.RawArray): data.
            true_label (None or int): ground-truth label used for virtual
                feedback during training.

        Returns:
            int: unified label (-1: keep, 0: rest, 1: cylinder, 2: ball,
                3: flex, 4: double, 5: treble)
        """
        virtual_feedback = self.virtual_feedback(true_label)
        if virtual_feedback is not None:
            return virtual_feedback

        if self.real_feedback_model is not None:
            fs, data = self.real_feedback_model.parse_data(data)
            p = self.real_feedback_model.step_probability(fs, data)
            pred = np.argmax(p)
            real_decision = self.real_feedback_model.model.classes_[pred]
            return real_decision
        else:
            raise ValueError('Neither decision model nor true label are given')

    def decision(self, data, true_label=None):
        """Main decision method. Output logic:

        If a decision model is present, it always performs one decision step
        (whether or not true_label is given) and accumulates the result into
        the internal buffer (the result is not necessarily returned).

        If true_label is given (training mode), draw a random number to decide
        whether this trial uses virtual feedback:
            if yes, draw another random number to decide whether the virtual
                feedback is correct or incorrect; incorrect virtual feedback is
                labelled 10000;
            if no, use the model decision.
        If true_label is not given (testing mode), use the model decision directly.

        Model decision logic:
            based on the model's recorded last_state,
            if the current state equals last_state, output -1;
            if the current state differs from last_state, output the current state.

        Args:
            data (mne.io.RawArray): data.
            true_label (None or int): ground-truth label used for virtual
                feedback during training.

        Returns:
            int: unified label (-1: keep, 0: rest, 1: cylinder, 2: ball,
                3: flex, 4: double, 5: treble)
        """
        if self.real_feedback_model is not None:
            real_decision = self.real_feedback_model.verterbi(data)
            # map to unified label
            if real_decision != -1:
                real_decision = self.real_feedback_model.model.classes_[real_decision]

        virtual_feedback = self.virtual_feedback(true_label)
        if virtual_feedback is not None:
            return virtual_feedback

        # true_label is None or this trial does not use virtual feedback;
        # if there is no real model either, raise ValueError
        if self.real_feedback_model is None:
            raise ValueError('Neither decision model nor true label are given')

        return real_decision

    def virtual_feedback(self, true_label=None):
        if true_label is not None:
            p = random.random()
            if p < self.virtual_feedback_rate:
                # virtual feedback (error rate 0.2)
                p_correct = random.random()
                if p_correct < 0.8:
                    return true_label
                else:
                    return 10000
        return None

    def reset_buffer(self):
        # call after every real feedback
        if self.real_feedback_model is not None:
            self.real_feedback_model.reset_state()
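# The sketch below is illustrative only and not part of the original API: it
# shows how a training loop might drive Controller. `data_stream` and `hand`
# are assumed placeholders for the acquisition source and the pneumatic-hand
# driver, and the model path is hypothetical. Each epoch is assumed to be the
# (fs, event, data_array) tuple that HMMModel.parse_data expects.
def _example_training_loop(data_stream, hand, model_path='path/to/baseline_model.pkl'):
    controller = Controller(virtual_feedback_rate=0.3,
                            model_path=model_path,
                            state_change_threshold=0.6)
    for epoch, ground_truth in data_stream:
        label = controller.decision(epoch, true_label=ground_truth)
        if label != -1:
            # -1 means "keep the current gesture"; anything else is acted on
            hand.execute(label)          # assumed driver call, illustrative only
            controller.reset_buffer()    # skip the hand's refractory period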
class HMMModel:
    """First-order HMM over gesture classes.

    State 0 (rest) acts as the hub state: rest can move to any gesture with
    equal probability, while a gesture can only persist or return to rest.
    """
    def __init__(self, n_classes=2, state_trans_prob=0.6, state_change_threshold=0.7):
        self.n_classes = n_classes
        self._probability = np.ones(n_classes) / n_classes
        self._last_state = 0
        self.state_change_threshold = state_change_threshold

        # TODO: train with daily use data
        # build the state transition matrix (row i = from state i, column j = to state j)
        self.state_trans_matrix = np.zeros((n_classes, n_classes))
        # probability of staying in the current state
        np.fill_diagonal(self.state_trans_matrix, state_trans_prob)
        # rest (state 0) -> each gesture state with equal probability
        self.state_trans_matrix[0, 1:] = (1 - state_trans_prob) / (n_classes - 1)
        # each gesture state -> rest
        self.state_trans_matrix[1:, 0] = 1 - state_trans_prob

    def reset_state(self):
        self._last_state = 0
        self._probability = np.ones(self.n_classes) / self.n_classes

    def set_current_state(self, current_state):
        self._last_state = current_state
        self._probability = np.zeros(self.n_classes)
        self._probability[current_state] = 1

    def step_probability(self, fs, data):
        raise NotImplementedError

    def parse_data(self, data):
        fs, event, data_array = data
        return fs, data_array

    def verterbi(self, data):
        """Single filtering step used by Controller.decision: parse the data,
        compute the observation probabilities and update the hidden state.
        """
        fs, data = self.parse_data(data)
        p = self.step_probability(fs, data)
        return self.update_state(p)

    def update_state(self, current_p):
        # forward (filtering) update of the state posterior:
        # predict with the row-stochastic transition matrix ...
        prior = self._probability @ self.state_trans_matrix
        # ... then weight by the observation probabilities and normalize
        self._probability = prior * current_p
        self._probability /= np.sum(self._probability)

        current_state = np.argmax(self._probability)
        if current_state == self._last_state:
            return -1
        if self._probability[current_state] > self.state_change_threshold:
            self.set_current_state(current_state)
            return current_state
        return -1


class BaselineHMM(HMMModel):
    def __init__(self, model, **kwargs):
        if isinstance(model, str):
            self.model = joblib.load(model)
        else:
            self.model = model
        self.freqs = np.arange(20, 150, 15)
        super(BaselineHMM, self).__init__(n_classes=len(self.model.classes_), **kwargs)

    def step_probability(self, fs, data):
        """Single-step class probabilities from one window of raw data."""
        # filter-bank features
        filter_bank_data = filterbank_extractor(data, fs, self.freqs, reshape_freqs_dim=True)
        # two-stage downsampling (total factor roughly fs / 10)
        decimate_rate = np.sqrt(fs / 10).astype(np.int16)
        filter_bank_data = signal.decimate(filter_bank_data, decimate_rate, axis=-1, zero_phase=True)
        filter_bank_data = signal.decimate(filter_bank_data, decimate_rate, axis=-1, zero_phase=True)
        # predict class probabilities
        p = self.model.predict_proba(filter_bank_data[None]).squeeze()
        return p


class RiemannHMM(HMMModel):
    def __init__(self, model, **kwargs):
        if isinstance(model, str):
            self.feat_extractor, self.scaler, self.cov, self.model = joblib.load(model)
        else:
            self.feat_extractor, self.scaler, self.cov, self.model = model
        super(RiemannHMM, self).__init__(n_classes=len(self.model.classes_), **kwargs)

    def step_probability(self, fs, data):
        """Single-step class probabilities from one window of raw data."""
        data = self.feat_extractor.transform(data)
        # scale data
        data = self.scaler.transform(data)
        # compute covariance features
        data = self.cov.transform(data)
        # predict class probabilities
        p = self.model.predict_proba(data[None]).squeeze()
        return p
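# Minimal smoke-test sketch (assumption: run via `python -m <package>.<module>`
# so the relative imports above resolve). `_ConstantHMM` is a toy subclass that
# is not part of the original code; it feeds fixed observation probabilities
# into the filtering update so the state-change threshold behaviour can be seen
# without a trained classifier.
if __name__ == '__main__':
    class _ConstantHMM(HMMModel):
        def step_probability(self, fs, data):
            # treat the array part of the parsed data directly as class probabilities
            return np.asarray(data, dtype=float)

    hmm = _ConstantHMM(n_classes=3, state_trans_prob=0.6, state_change_threshold=0.7)
    # Mild evidence for class 1: the posterior still favours rest (state 0),
    # which equals the last state, so -1 ("keep") is returned.
    print(hmm.verterbi((1000, None, [0.3, 0.4, 0.3])))
    # Strong, repeated evidence for class 1: the first call that pushes its
    # posterior above the threshold returns 1; later calls return -1 ("keep").
    for _ in range(3):
        print(hmm.verterbi((1000, None, [0.05, 0.9, 0.05])))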