Browse Source

Fix: 对step decision的event评估bug,Feat: 状态转移可视化

dk 1 year ago
parent
commit
a83b184261
4 changed files with 66 additions and 33 deletions
  1. 1 0
      backend/bci_core/online.py
  2. 9 0
      backend/bci_core/viz.py
  3. 39 17
      backend/online_sim.py
  4. 17 16
      backend/tests/test_validation.py

+ 1 - 0
backend/bci_core/online.py

@@ -169,6 +169,7 @@ class HMMModel:
     
     @property
     def probability(self):
+        # TODO: return each classes
         return np.max(self._probability[1:])  # largest prob except the rest state
 
 

+ 9 - 0
backend/bci_core/viz.py

@@ -118,3 +118,12 @@ def plot_confusion_matrix(y_true, y_pred):
     disp = ConfusionMatrixDisplay(cm)
     disp.plot()
     return disp.figure_
+
+
+def plot_state_seq_with_cue(time_range, true_states, pred_probs, ax):
+    # normalize
+    pred_probs /= pred_probs.max()
+    ax.plot(np.linspace(*time_range, len(pred_probs)), pred_probs)
+    true_states = (true_states > 0)
+    ax.fill_between(np.linspace(*time_range, len(true_states)), true_states, where=(true_states > 0), color='gray', alpha=0.6)
+    return ax

+ 39 - 17
backend/online_sim.py

@@ -13,6 +13,7 @@ from sklearn.metrics import accuracy_score
 from dataloaders import neo
 import bci_core.online as online
 import bci_core.utils as bci_utils
+import bci_core.viz as bci_viz
 from settings.config import settings
 
 
@@ -87,14 +88,18 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
     decision_with_hmm = []
     decision_without_hmm = []
     probs = []
+    probs_naive = []
     for time, data in data_gen.loop(step_length):
         step_p, cls = model_hmm.viterbi(data, return_step_p=True)
         if cls >=0:
             cls = model_hmm.model.classes_[cls]
         decision_with_hmm.append((time, cls))  # map to unified label
         decision_without_hmm.append((time, model_hmm.model.classes_[np.argmax(step_p)]))
-        probs.append((time, model_hmm.probability))
+        probs.append(model_hmm.probability)
+        # TODO: match multiclass
+        probs_naive.append(step_p[1])
     probs = np.array(probs)
+    probs_naive = np.array(probs_naive)
     
     events_pred = _construct_model_event(decision_with_hmm, fs)
     events_pred_naive = _construct_model_event(decision_without_hmm, fs)
@@ -110,16 +115,19 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
     accu_hmm = accuracy_score(stim_true, stim_pred)
     accu_naive = accuracy_score(stim_true, stim_pred_naive)
 
-    fig_pred, ax = plt.subplots(4, 1, sharex=True, sharey=False)
-    ax[0].set_title('pred (naive)')
-    ax[0].plot(raw.times, stim_pred_naive)
-    ax[1].set_title('pred')
-    ax[1].plot(raw.times, stim_pred)
-    ax[2].set_title('true')
-    ax[2].plot(raw.times, stim_true)
-    ax[3].set_title('prob')
-    ax[3].plot(probs[:, 0], probs[:, 1])
-    ax[3].set_ylim([0, 1])
+    # TODO: match multiclass (one-hot)
+    fig_pred, axes = plt.subplots(5, 1, sharex=True, figsize=(10, 8))
+    axes[0].set_title('True states')
+    axes[0].plot(raw.times, stim_true)
+    axes[0].set_axis_off()
+    axes[1].set_title('With HMM (probs)')
+    bci_viz.plot_state_seq_with_cue((raw.times[0], raw.times[-1]), stim_true, probs, ax=axes[1])
+    axes[2].set_title('Without HMM (probs)')
+    bci_viz.plot_state_seq_with_cue((raw.times[0], raw.times[-1]), stim_true, probs_naive, ax=axes[2])
+    axes[3].set_title('With HMM')
+    bci_viz.plot_state_seq_with_cue((raw.times[0], raw.times[-1]), stim_true, stim_pred, ax=axes[3])
+    axes[4].set_title('Without HMM')
+    bci_viz.plot_state_seq_with_cue((raw.times[0], raw.times[-1]), stim_true, stim_pred_naive, ax=axes[4])
 
     return fig_pred, (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)
 
@@ -155,12 +163,24 @@ def simulation(raw_val, event_id, model,
     return metric_hmm, metric_naive, fig_pred
 
 
-def _construct_model_event(decision_seq, fs):
-    events = []
-    for i in decision_seq:
-        time, cls = i
-        if cls >= 0:
-            events.append([int(time * fs), 0, cls])
+def _construct_model_event(decision_seq, fs, start_cond=0):
+    def _filter_seq(decision_seq):
+        new_seq = [(decision_seq[0][0], start_cond)]
+        for i in range(1, len(decision_seq)):
+            if decision_seq[i][1] == -1:
+                new_seq.append((decision_seq[i][0], new_seq[-1][1]))
+            else:
+                new_seq.append(decision_seq[i])
+        return new_seq
+    decision_seq = _filter_seq(decision_seq)
+
+    last_state = decision_seq[0][1]
+    events = [(int(decision_seq[0][0] * fs), 0, last_state)]
+    for i in range(1, len(decision_seq)):
+        time, label = decision_seq[i]
+        if label != last_state:
+            last_state = label
+            events.append([int(time * fs), 0, label])
     return np.array(events)
 
 
@@ -209,3 +229,5 @@ if __name__ == '__main__':
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   
     logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
     logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')
+
+    plt.show()

+ 17 - 16
backend/tests/test_validation.py

@@ -4,11 +4,13 @@ import numpy as np
 from glob import glob
 import shutil
 
+import mne
+
 from bci_core import utils as ana_utils
 from bci_core.online import model_loader
 from training import train_model, model_saver
 from dataloaders import neo
-from online_sim import simulation
+from online_sim import simulation, _construct_model_event
 from validation import val_by_epochs
 
 
@@ -17,18 +19,8 @@ class TestOnlineSim(unittest.TestCase):
     def setUpClass(cls):
         root_path = './tests/data'
 
-        raw, cls.event_id = neo.raw_loader(root_path, {'flex': ['1', '2']})
-        cls.raw = raw
-        # split into 2 pieces
-        t_min, t_max = raw.times[0], raw.times[-1]
-        t_mid = raw.times[len(raw.times) // 2]
-        raw_train = raw.copy().crop(tmin=t_min, tmax=t_mid, include_tmax=True)
-        cls.raw_val = raw.copy().crop(tmin=t_mid, tmax=t_max)
-
-        # reconstruct single event for validation
-        if cls.raw_val.annotations.onset[0] > t_mid:
-            # correct time by first timestamp
-            cls.raw_val.annotations.onset -= t_mid
+        raw_train, cls.event_id = neo.raw_loader(root_path, {'flex': ['1']})
+        cls.raw_val, _ = neo.raw_loader(root_path, {'flex': ['2']}, upsampled_epoch_length=None)
         
         # train with the first half
         model = train_model(raw_train, event_id=cls.event_id, model_type='baseline')
@@ -49,14 +41,23 @@ class TestOnlineSim(unittest.TestCase):
         self.assertEqual(precision, 1 / 2)
         self.assertEqual(recall, 1)
 
+    def test_construct_event(self):
+        seq_1 = [(1, -1), (2, -1), (3, -1), (4, 1)]
+        seq_2 = [(1, 0), (2, 0), (4, 1)]
+        gt = [[1, 0, 0], [4, 0, 1]]
+        ret_ = _construct_model_event(seq_1, 1, start_cond=0)
+        self.assertTrue(np.allclose(gt, ret_))
+        ret_ = _construct_model_event(seq_2, 1, start_cond=0)
+        self.assertTrue(np.allclose(gt, ret_))
+
     def test_sim(self):
         model = model_loader(self.model_path, 
                              state_change_threshold=0.7,
                              state_trans_prob=0.7)
-        metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=model, epoch_length=1., step_length=0.1)
+        metric_hmm, metric_nohmm, fig_pred = simulation(self.raw_val, self.event_id, model=model, epoch_length=1., step_length=0.1)
         fig_pred.savefig('./tests/data/pred.pdf')   
-        self.assertTrue(metric_hmm[-2] > 0.3)  # f1-score (with hmm)
-        self.assertTrue(metric_nohmm[-2] < 0.15)  # f1-score (without hmm)
+        self.assertTrue(metric_hmm[-2] > 0.7)  # f1-score (with hmm)
+        self.assertTrue(metric_nohmm[-2] < 0.4)  # f1-score (without hmm)
     
     def test_val_model(self):
         metrices, fig_conf = val_by_epochs(self.raw_val, self.model_path, self.event_id, 1.)