|
@@ -4,7 +4,7 @@ import random
|
|
|
import logging
|
|
|
import os
|
|
|
from scipy import signal
|
|
|
-from .utils import parse_model_type
|
|
|
+from .utils import parse_model_type, reref
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
@@ -21,10 +21,12 @@ class Controller:
|
|
|
"""
|
|
|
def __init__(self,
|
|
|
virtual_feedback_rate=1.,
|
|
|
- real_feedback_model=None):
|
|
|
+ real_feedback_model=None,
|
|
|
+ reref_method='monopolar'):
|
|
|
|
|
|
self.real_feedback_model = real_feedback_model
|
|
|
self.virtual_feedback_rate = virtual_feedback_rate
|
|
|
+ self.reref_method = reref_method
|
|
|
|
|
|
def step_decision(self, data, true_label=None):
|
|
|
"""抓握训练调用接口,只进行单次判决,不涉及马尔可夫过程,
|
|
@@ -41,7 +43,7 @@ class Controller:
|
|
|
return virtual_feedback
|
|
|
|
|
|
if self.real_feedback_model is not None:
|
|
|
- fs, data = self.real_feedback_model.parse_data(data)
|
|
|
+ fs, data = self.parse_data(data)
|
|
|
p = self.real_feedback_model.step_probability(fs, data)
|
|
|
logger.debug('step_decison: model probability: {}'.format(str(p)))
|
|
|
pred = np.argmax(p)
|
|
@@ -71,7 +73,8 @@ class Controller:
|
|
|
int: 统一化标签 (-1: keep, 0: rest, 1: cylinder, 2: ball, 3: flex, 4: double, 5: treble)
|
|
|
"""
|
|
|
if self.real_feedback_model is not None:
|
|
|
- real_decision = self.real_feedback_model.viterbi(data)
|
|
|
+ fs, data = self.parse_data(data)
|
|
|
+ real_decision = self.real_feedback_model.viterbi(fs, data)
|
|
|
# map to unified label
|
|
|
if real_decision != -1:
|
|
|
real_decision = self.real_feedback_model.model.classes_[real_decision]
|
|
@@ -96,10 +99,20 @@ class Controller:
|
|
|
else:
|
|
|
return 10000
|
|
|
return None
|
|
|
+
|
|
|
+ def parse_data(self, data):
|
|
|
+ fs, event, data_array = data
|
|
|
+ # do preprocessing
|
|
|
+ data_array = reref(data_array, self.reref_method)
|
|
|
+ return fs, data_array
|
|
|
|
|
|
|
|
|
class HMMModel:
|
|
|
- def __init__(self, transmat=None, n_classes=2, state_trans_prob=0.6, state_change_threshold=0.5):
|
|
|
+ def __init__(self,
|
|
|
+ transmat=None,
|
|
|
+ n_classes=2,
|
|
|
+ state_trans_prob=0.6,
|
|
|
+ state_change_threshold=0.5):
|
|
|
self.n_classes = n_classes
|
|
|
self.set_current_state(0)
|
|
|
|
|
@@ -124,21 +137,13 @@ class HMMModel:
|
|
|
self._probability[current_state] = 1.
|
|
|
|
|
|
def step_probability(self, fs, data):
|
|
|
- # do preprocessing here
|
|
|
- # common average
|
|
|
- data -= data.mean(axis=0)
|
|
|
- return data
|
|
|
-
|
|
|
- def parse_data(self, data):
|
|
|
- fs, event, data_array = data
|
|
|
- return fs, data_array
|
|
|
+ raise NotImplementedError
|
|
|
|
|
|
- def viterbi(self, data, return_step_p=False):
|
|
|
+ def viterbi(self, fs, data, return_step_p=False):
|
|
|
"""
|
|
|
Interface for class decision
|
|
|
|
|
|
"""
|
|
|
- fs, data = self.parse_data(data)
|
|
|
p = self.step_probability(fs, data)
|
|
|
if return_step_p:
|
|
|
return p, self.update_state(p)
|
|
@@ -164,7 +169,6 @@ class HMMModel:
|
|
|
|
|
|
@property
|
|
|
def probability(self):
|
|
|
- # TODO: return each classes
|
|
|
return self._probability.copy()
|
|
|
|
|
|
|
|
@@ -179,7 +183,6 @@ class BaselineHMM(HMMModel):
|
|
|
def step_probability(self, fs, data):
|
|
|
"""Step
|
|
|
"""
|
|
|
- data = super(BaselineHMM, self).step_probability(fs, data)
|
|
|
# filter data
|
|
|
filter_bank_data = self.feat_extractor.transform(data)
|
|
|
# downsampling
|
|
@@ -203,7 +206,6 @@ class RiemannHMM(HMMModel):
|
|
|
def step_probability(self, fs, data):
|
|
|
"""Step
|
|
|
"""
|
|
|
- data = super(RiemannHMM, self).step_probability(fs, data)
|
|
|
data = self.feat_extractor.transform(data)
|
|
|
data = data[None] # pad trial dimension
|
|
|
# scale data
|