Przeglądaj źródła

Merge branch 'hmm-evaluation' of dk/kraken into master

dk 1 rok temu
rodzic
commit
31b09e3178

+ 6 - 3
backend/bci_core/online.py

@@ -80,7 +80,7 @@ class Controller:
             int: 统一化标签 (-1: keep, 0: rest, 1: cylinder, 2: ball, 3: flex, 4: double, 5: treble)
         """
         if self.real_feedback_model is not None:
-            real_decision = self.real_feedback_model.verterbi(data)
+            real_decision = self.real_feedback_model.viterbi(data)
             # map to unified label
             if real_decision != -1:
                 real_decision = self.real_feedback_model.model.classes_[real_decision]
@@ -143,14 +143,17 @@ class HMMModel:
         fs, event, data_array = data
         return fs, data_array
     
-    def verterbi(self, data):
+    def viterbi(self, data, return_step_p=False):
         """
             Interface for class decision
 
         """
         fs, data = self.parse_data(data)
         p = self.step_probability(fs, data)
-        return self.update_state(p)
+        if return_step_p:
+            return p, self.update_state(p)
+        else:
+            return self.update_state(p)
     
     def update_state(self, current_p):
         # Viterbi algorithm

+ 4 - 4
backend/dataloaders/neo.py

@@ -26,7 +26,7 @@ def raw_preprocessing(data_root, session_paths:dict,
         subj_root: 
         session_paths: dict of lists
         do_rereference (bool): do common average rereference or not
-        upsampled_epoch_length: 
+        upsampled_epoch_length (None or float): None: do not do upsampling
         ori_epoch_length (int or 'varied'): original epoch length in second
         unify_label (True, original data didn't use unified event label, do unification (old pony format); False use original)
         mov_trial_ind: only used when unify_label == True, suggesting the raw file's annotations didn't use unified labels (old pony format)
@@ -54,9 +54,9 @@ def raw_preprocessing(data_root, session_paths:dict,
                                     rest_trial_ind=rest_trial_ind,
                                     trial_duration=trial_duration, 
                                     use_original_label=not unify_label)
-        
-        events_upsampled = upsample_events(events, int(fs * upsampled_epoch_length))
-        annotations = mne.annotations_from_events(events_upsampled, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
+        if upsampled_epoch_length is not None:
+            events = upsample_events(events, int(fs * upsampled_epoch_length))
+        annotations = mne.annotations_from_events(events, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
         raw.set_annotations(annotations)
         raws.append(raw)
 

+ 3 - 3
backend/tests/test_validation.py

@@ -50,12 +50,12 @@ class TestValidation(unittest.TestCase):
         self.assertEqual(recall, 1)
 
     def test_validation(self):
-        (precision, recall, f1_score, r), fig_erds, fig_pred = validation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7)
+        metric_hmm, metric_nohmm, fig_erds, fig_pred = validation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7)
         fig_erds.savefig('./tests/data/erds.pdf')
         fig_pred.savefig('./tests/data/pred.pdf')   
 
-        self.assertTrue(f1_score > 0.9)
-        self.assertTrue(r > 0.5)
+        self.assertTrue(metric_hmm[-2] > 0.9)  # f1-score (with hmm)
+        self.assertTrue(metric_nohmm[-2] < 0.5)  # f1-score (without hmm)
 
 
 if __name__ == '__main__':

+ 79 - 39
backend/validation.py

@@ -9,7 +9,7 @@ import yaml
 import os
 import argparse
 import logging
-from scipy import stats
+from sklearn.metrics import accuracy_score
 from dataloaders import neo
 import bci_core.online as online
 import bci_core.utils as bci_utils
@@ -71,14 +71,65 @@ class DataGenerator:
             yield i / self.fs, self.get_data_batch(i)
 
 
-def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
-    """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
+def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
+    val_data = raw.get_data()
+    fs = raw.info['sfreq']
+
+    data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
+
+    decision_with_hmm = []
+    decision_without_hmm = []
+    probs = []
+    for time, data in data_gen.loop():
+        step_p, cls = model_hmm.viterbi(data, return_step_p=True)
+        if cls >=0:
+            cls = model_hmm.model.classes_[cls]
+        decision_with_hmm.append((time, cls))  # map to unified label
+        decision_without_hmm.append((time, model_hmm.model.classes_[np.argmax(step_p)]))
+        probs.append((time, model_hmm.probability))
+    probs = np.array(probs)
+    
+    events_pred = _construct_model_event(decision_with_hmm, fs)
+    events_pred_naive = _construct_model_event(decision_without_hmm, fs)
+
+    p_hmm, r_hmm, f1_hmm = bci_utils.event_metric(event_true=events, event_pred=events_pred, fs=fs)
+
+    p_n, r_n, f1_n = bci_utils.event_metric(events, events_pred_naive, fs=fs)
+
+    stim_true = _event_to_stim_channel(events, len(raw.times), trial_length=int(event_trial_length * fs))
+    stim_pred = _event_to_stim_channel(events_pred, len(raw.times))
+    stim_pred_naive = _event_to_stim_channel(events_pred_naive, len(raw.times))
+
+    # TODO: auc
+    accu_hmm = accuracy_score(stim_true, stim_pred)
+    accu_naive = accuracy_score(stim_true, stim_pred_naive)
+
+    fig_pred, ax = plt.subplots(4, 1, sharex=True, sharey=False)
+    ax[0].set_title('pred (naive)')
+    ax[0].plot(raw.times, stim_pred_naive)
+    ax[1].set_title('pred')
+    ax[1].plot(raw.times, stim_pred)
+    ax[2].set_title('true')
+    ax[2].plot(raw.times, stim_true)
+    ax[3].set_title('prob')
+    ax[3].plot(probs[:, 0], probs[:, 1])
+    ax[3].set_ylim([0, 1])
+
+    return fig_pred, (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)
+
+
+def validation(raw_val, event_id, model, 
+               state_change_threshold=0.8, 
+               step_length=1., 
+               event_trial_length=5.):
+    """模型验证接口,使用指定数据进行验证,绘制ersd map
     Args:
         raw_val (mne.io.Raw): recording used for validation
         event_id (dict)
         model: validate existing model, 
         state_change_threshold (float): default 0.8
         step_length (float): batch data step length, default 1. (s)
+        event_trial_length (float): duration (s) of each ground-truth trial, used to build the true stim channel, default 5.
 
     Returns:
         tuple: (metric_hmm, metric_naive, fig_erds, fig_pred); each metric is (precision, recall, f1, accuracy)
@@ -99,38 +150,16 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length
     
     
     controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
+    model_hmm = controller.real_feedback_model
 
-    # validate with the second half
-    val_data = raw_val.get_data()
-    data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
-
-    decisions = []
-    probs = []
-    for time, data in data_gen.loop():
-        cls = controller.decision(data)
-        decisions.append((time, cls))
-        probs.append((time, controller.real_feedback_model.probability))
-    probs = np.array(probs)
-    
-    events_pred = _construct_model_event(decisions, fs)
-
-    precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
-
-    stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
-    stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
-
-    corr, _ = stats.pearsonr(stim_pred, stim_true)
-
-    fig_pred, ax = plt.subplots(3, 1, sharex=True, sharey=False)
-    ax[0].set_title('pred')
-    ax[0].plot(raw_val.times, stim_pred)
-    ax[1].set_title('true')
-    ax[1].plot(raw_val.times, stim_true)
-    ax[2].set_title('prob')
-    ax[2].plot(probs[:, 0], probs[:, 1])
-    ax[2].set_ylim([0, 1])
+    # run with and without hmm
+    fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, 
+                                                          events_val, 
+                                                          model_hmm, 
+                                                          step_length, 
+                                                          event_trial_length=event_trial_length)
 
-    return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
+    return metric_hmm, metric_naive, fig_erds, fig_pred
 
 
 def _construct_model_event(decision_seq, fs):
@@ -142,10 +171,13 @@ def _construct_model_event(decision_seq, fs):
     return np.array(events)
 
 
-def _event_to_stim_channel(events, time_length):
+def _event_to_stim_channel(events, time_length, trial_length=None):
     x = np.zeros(time_length)
     for i in range(0, len(events) - 1):
-        x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
+        if trial_length is not None:
+            x[events[i, 0]: events[i, 0] + trial_length] = events[i, 2]
+        else:
+            x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
     return x
 
 
@@ -164,14 +196,22 @@ if __name__ == '__main__':
         event_id[f] = neo.FINGERMODEL_IDS[f]
     
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+    trial_time = 5.
+    raw = neo.raw_preprocessing(data_dir, sessions, 
+                                unify_label=True, 
+                                ori_epoch_length=trial_time, 
+                                mov_trial_ind=[2], 
+                                rest_trial_ind=[1], 
+                                upsampled_epoch_length=None)
 
     # do validations
-    metrics, fig_erds, fig_pred = validation(raw, 
+    metric_hmm, metric_naive, fig_erds, fig_pred = validation(raw, 
                                              event_id, 
                                              model=model_path, 
                                              state_change_threshold=args.state_change_threshold,
-                                             step_length=config_info['buffer_length'])
+                                             step_length=config_info['buffer_length'],
+                                             event_trial_length=trial_time)
     fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   
-    logger.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')
+    logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
+    logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')