Browse Source

Update model evaluation

dk · 1 year ago
parent commit 076cf9107a
3 changed files with 67 additions and 41 deletions
  1. backend/bci_core/online.py (+6 -3)
  2. backend/tests/test_validation.py (+3 -3)
  3. backend/validation.py (+58 -35)

+ 6 - 3
backend/bci_core/online.py

@@ -80,7 +80,7 @@ class Controller:
             int: unified label (-1: keep, 0: rest, 1: cylinder, 2: ball, 3: flex, 4: double, 5: treble)
         """
         if self.real_feedback_model is not None:
-            real_decision = self.real_feedback_model.verterbi(data)
+            real_decision = self.real_feedback_model.viterbi(data)
             # map to unified label
             if real_decision != -1:
                 real_decision = self.real_feedback_model.model.classes_[real_decision]
@@ -143,14 +143,17 @@ class HMMModel:
         fs, event, data_array = data
         return fs, data_array
     
-    def verterbi(self, data):
+    def viterbi(self, data, return_step_p=False):
         """
             Interface for class decision
 
         """
         fs, data = self.parse_data(data)
         p = self.step_probability(fs, data)
-        return self.update_state(p)
+        if return_step_p:
+            return p, self.update_state(p)
+        else:
+            return self.update_state(p)
     
     def update_state(self, current_p):
         # Viterbi algorithm
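
The `return_step_p` flag added here lets callers retrieve the raw per-step class probabilities alongside the smoothed Viterbi decision. A minimal usage sketch (assuming `model_hmm` is an `HMMModel` instance and `data` is the `(fs, event, data_array)` tuple that `parse_data` expects), mirroring how validation.py consumes the new interface below:

    import numpy as np

    # With the flag set, viterbi() returns (step probabilities, state decision).
    step_p, cls = model_hmm.viterbi(data, return_step_p=True)
    if cls >= 0:
        cls = model_hmm.model.classes_[cls]  # map state index to unified label
    # Baseline without HMM smoothing: argmax over the raw step probabilities.
    naive_cls = model_hmm.model.classes_[np.argmax(step_p)]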

+ 3 - 3
backend/tests/test_validation.py

@@ -50,12 +50,12 @@ class TestValidation(unittest.TestCase):
         self.assertEqual(recall, 1)
 
     def test_validation(self):
-        (precision, recall, f1_score, r), fig_erds, fig_pred = validation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7)
+        metric_hmm, metric_nohmm, fig_erds, fig_pred = validation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7)
         fig_erds.savefig('./tests/data/erds.pdf')
         fig_pred.savefig('./tests/data/pred.pdf')   
 
-        self.assertTrue(f1_score > 0.9)
-        self.assertTrue(r > 0.5)
+        self.assertTrue(metric_hmm[-2] > 0.9)  # f1-score (with hmm)
+        self.assertTrue(metric_nohmm[-2] < 0.5)  # f1-score (without hmm)
 
 
 if __name__ == '__main__':
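
The updated assertions follow the new return contract of `validation()`: two four-element metric tuples ordered `(precision, recall, f1_score, accuracy)`, as assembled in validation.py below, so index `-2` selects the f1-score. A sketch of the unpacking (names and threshold taken from this test):

    metric_hmm, metric_nohmm, fig_erds, fig_pred = validation(
        raw, event_id, model=model_path, state_change_threshold=0.7)
    precision, recall, f1_score, accuracy = metric_hmm  # metric_nohmm has the same layout
    assert f1_score == metric_hmm[-2]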

+ 58 - 35
backend/validation.py

@@ -9,7 +9,7 @@ import yaml
 import os
 import argparse
 import logging
-from scipy import stats
+from sklearn.metrics import accuracy_score
 from dataloaders import neo
 import bci_core.online as online
 import bci_core.utils as bci_utils
@@ -71,8 +71,55 @@ class DataGenerator:
             yield i / self.fs, self.get_data_batch(i)
 
 
+def _evaluation_loop(raw, events, model_hmm, step_length):
+    val_data = raw.get_data()
+    fs = raw.info['sfreq']
+
+    data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
+
+    decision_with_hmm = []
+    decision_without_hmm = []
+    probs = []
+    for time, data in data_gen.loop():
+        step_p, cls = model_hmm.viterbi(data, return_step_p=True)
+        if cls >= 0:
+            cls = model_hmm.model.classes_[cls]
+        decision_with_hmm.append((time, cls))  # map to unified label
+        decision_without_hmm.append((time, model_hmm.model.classes_[np.argmax(step_p)]))
+        probs.append((time, model_hmm.probability))
+    probs = np.array(probs)
+    
+    events_pred = _construct_model_event(decision_with_hmm, fs)
+    events_pred_naive = _construct_model_event(decision_without_hmm, fs)
+
+    p_hmm, r_hmm, f1_hmm = bci_utils.event_metric(event_true=events, event_pred=events_pred, fs=fs)
+
+    p_n, r_n, f1_n = bci_utils.event_metric(events, events_pred_naive, fs=fs)
+
+    stim_true = _event_to_stim_channel(events, len(raw.times))
+    stim_pred = _event_to_stim_channel(events_pred, len(raw.times))
+    stim_pred_naive = _event_to_stim_channel(events_pred_naive, len(raw.times))
+
+    # TODO: auc
+    accu_hmm = accuracy_score(stim_true, stim_pred)
+    accu_naive = accuracy_score(stim_true, stim_pred_naive)
+
+    fig_pred, ax = plt.subplots(4, 1, sharex=True, sharey=False)
+    ax[0].set_title('pred (naive)')
+    ax[0].plot(raw.times, stim_pred_naive)
+    ax[1].set_title('pred')
+    ax[1].plot(raw.times, stim_pred)
+    ax[2].set_title('true')
+    ax[2].plot(raw.times, stim_true)
+    ax[3].set_title('prob')
+    ax[3].plot(probs[:, 0], probs[:, 1])
+    ax[3].set_ylim([0, 1])
+
+    return fig_pred, (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)
+
+
 def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
-    """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
+    """模型验证接口,使用指定数据进行验证,绘制ersd map
     Args:
         raw_val (mne.io.Raw)
         event_id (dict)
@@ -99,38 +146,12 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length
     
     
     controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
+    model_hmm = controller.real_feedback_model
 
-    # validate with the second half
-    val_data = raw_val.get_data()
-    data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
-
-    decisions = []
-    probs = []
-    for time, data in data_gen.loop():
-        cls = controller.decision(data)
-        decisions.append((time, cls))
-        probs.append((time, controller.real_feedback_model.probability))
-    probs = np.array(probs)
-    
-    events_pred = _construct_model_event(decisions, fs)
-
-    precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
-
-    stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
-    stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
-
-    corr, _ = stats.pearsonr(stim_pred, stim_true)
-
-    fig_pred, ax = plt.subplots(3, 1, sharex=True, sharey=False)
-    ax[0].set_title('pred')
-    ax[0].plot(raw_val.times, stim_pred)
-    ax[1].set_title('true')
-    ax[1].plot(raw_val.times, stim_true)
-    ax[2].set_title('prob')
-    ax[2].plot(probs[:, 0], probs[:, 1])
-    ax[2].set_ylim([0, 1])
+    # run with and without hmm
+    fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, events_val, model_hmm, step_length)
 
-    return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
+    return metric_hmm, metric_naive, fig_erds, fig_pred
 
 
 def _construct_model_event(decision_seq, fs):
@@ -164,14 +185,16 @@ if __name__ == '__main__':
         event_id[f] = neo.FINGERMODEL_IDS[f]
     
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+    trial_duration = config_info['buffer_length']
+    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1], upsampled_epoch_length=trial_duration)
 
     # do validations
-    metrics, fig_erds, fig_pred = validation(raw, 
+    metric_hmm, metric_naive, fig_erds, fig_pred = validation(raw, 
                                              event_id, 
                                              model=model_path, 
                                              state_change_threshold=args.state_change_threshold,
                                              step_length=config_info['buffer_length'])
     fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   
-    logger.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')
+    logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
+    logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')
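
`_event_to_stim_channel` is used above but not touched by this diff; the sample-wise `accuracy_score` comparison only needs each event sequence expanded into a per-sample label vector of length `len(raw.times)`. A hypothetical sketch, assuming events are `(onset_sample, _, label)` rows and each label holds until the next onset:

    import numpy as np

    def _event_to_stim_channel_sketch(events, n_times):
        """Hypothetical stand-in for _event_to_stim_channel (not shown in this diff)."""
        stim = np.zeros(n_times, dtype=int)
        for i, (onset, _, label) in enumerate(events):
            end = events[i + 1][0] if i + 1 < len(events) else n_times
            stim[onset:end] = label  # label persists until the next event onset
        return stim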