Parcourir la source

Feat: smooth p

dk il y a 1 an
Parent
commit
a7b3508679
1 fichiers modifiés avec 17 ajouts et 1 suppressions
  1. 17 1
      backend/bci_core/online.py

+ 17 - 1
backend/bci_core/online.py

@@ -119,6 +119,10 @@ class HMMModel:
                 transmat = np.loadtxt(transmat)
             self.state_trans_matrix = transmat
 
+        # emission probability moving average, (5 steps)
+        self._filter_b = np.ones(5) / 5
+        self._z = np.zeros((len(self._filter_b) - 1, n_classes))
+
     def reset_state(self):
         self._probability[0] = 1.
         self._last_state = 0
@@ -138,6 +142,16 @@ class HMMModel:
         fs, event, data_array = data
         return fs, data_array
     
+    def filter_prob(self, probs):
+        """
+        Args: 
+            probs (np.ndarray): (n_classes,)
+        Returns:
+            filtered_probs (np.ndarray): (n_classes,)
+        """
+        filtered_probs, self._z = signal.lfilter(self._filter_b, 1, probs[None], axis=0, zi=self._z)
+        return filtered_probs.squeeze()
+    
     def viterbi(self, data, return_step_p=False):
         """
             Interface for class decision
@@ -145,6 +159,9 @@ class HMMModel:
         """
         fs, data = self.parse_data(data)
         p = self.step_probability(fs, data)
+        # smooth p
+        p = self.filter_prob(p)
+
         if return_step_p:
             return p, self.update_state(p)
         else:
@@ -169,7 +186,6 @@ class HMMModel:
     
     @property
     def probability(self):
-        # TODO: return each classes
         return self._probability.copy()