|
@@ -119,6 +119,10 @@ class HMMModel:
|
|
transmat = np.loadtxt(transmat)
|
|
transmat = np.loadtxt(transmat)
|
|
self.state_trans_matrix = transmat
|
|
self.state_trans_matrix = transmat
|
|
|
|
|
|
|
|
+ # emission probability moving average, (5 steps)
|
|
|
|
+ self._filter_b = np.ones(5) / 5
|
|
|
|
+ self._z = np.zeros((len(self._filter_b) - 1, n_classes))
|
|
|
|
+
|
|
def reset_state(self):
|
|
def reset_state(self):
|
|
self._probability[0] = 1.
|
|
self._probability[0] = 1.
|
|
self._last_state = 0
|
|
self._last_state = 0
|
|
@@ -138,6 +142,16 @@ class HMMModel:
|
|
fs, event, data_array = data
|
|
fs, event, data_array = data
|
|
return fs, data_array
|
|
return fs, data_array
|
|
|
|
|
|
|
|
+ def filter_prob(self, probs):
|
|
|
|
+ """
|
|
|
|
+ Args:
|
|
|
|
+ probs (np.ndarray): (n_classes,)
|
|
|
|
+ Returns:
|
|
|
|
+ filtered_probs (np.ndarray): (n_classes,)
|
|
|
|
+ """
|
|
|
|
+ filtered_probs, self._z = signal.lfilter(self._filter_b, 1, probs[None], axis=0, zi=self._z)
|
|
|
|
+ return filtered_probs.squeeze()
|
|
|
|
+
|
|
def viterbi(self, data, return_step_p=False):
|
|
def viterbi(self, data, return_step_p=False):
|
|
"""
|
|
"""
|
|
Interface for class decision
|
|
Interface for class decision
|
|
@@ -145,6 +159,9 @@ class HMMModel:
|
|
"""
|
|
"""
|
|
fs, data = self.parse_data(data)
|
|
fs, data = self.parse_data(data)
|
|
p = self.step_probability(fs, data)
|
|
p = self.step_probability(fs, data)
|
|
|
|
+ # smooth p
|
|
|
|
+ p = self.filter_prob(p)
|
|
|
|
+
|
|
if return_step_p:
|
|
if return_step_p:
|
|
return p, self.update_state(p)
|
|
return p, self.update_state(p)
|
|
else:
|
|
else:
|
|
@@ -169,7 +186,6 @@ class HMMModel:
|
|
|
|
|
|
@property
|
|
@property
|
|
def probability(self):
|
|
def probability(self):
|
|
- # TODO: return each classes
|
|
|
|
return self._probability.copy()
|
|
return self._probability.copy()
|
|
|
|
|
|
|
|
|