Browse Source

Merge branch 'riemann' of dk/kraken into master

dk 1 year ago
parent
commit
4a1f3d471b
1 changed files with 23 additions and 10 deletions
  1. 23 10
      backend/validation.py

+ 23 - 10
backend/validation.py

@@ -16,7 +16,7 @@ import bci_core.viz as bci_viz
 from settings.config import settings
 
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)
 config_info = settings.CONFIG_INFO
 
 
@@ -71,19 +71,32 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length
     # validate with the second half
     val_data = raw_val.get_data()
     data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
-    rets = []
+
+    decisions = []
+    probs = []
     for time, data in data_gen.loop():
         cls = controller.decision(data)
-        rets.append((time, cls))
-    events_pred = _construct_model_event(rets, fs)
+        decisions.append((time, cls))
+        probs.append((time, controller.real_feedback_model.probability))
+    probs = np.array(probs)
+    
+    events_pred = _construct_model_event(decisions, fs)
+
     precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
+
     stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
     stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
-    corr, p = stats.pearsonr(stim_pred, stim_true)
-    fig_pred, ax = plt.subplots(1, 1)
-    ax.plot(raw_val.times, stim_pred, label='pred')
-    ax.plot(raw_val.times, stim_true, label='true')
-    ax.legend()
+
+    corr, _ = stats.pearsonr(stim_pred, stim_true)
+
+    fig_pred, ax = plt.subplots(1, 1, sharex=True, sharey=False)
+    ax[0].set_title('pred')
+    ax[0].plot(raw_val.times, stim_pred)
+    ax[1].set_title('true')
+    ax[1].plot(raw_val.times, stim_true)
+    ax[2].set_title('prob')
+    ax[2].plot(probs[:, 0], probs[:, 1])
+    ax[2].set_ylim([0, 1])
 
     return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
 
@@ -110,7 +123,7 @@ if __name__ == '__main__':
     # TODO: load subject config
 
     data_dir = f'./data/{subj_name}/'
-    model_path = f'./static/models/{subj_name}/riemann_rest+flex_11-21-2023-16-43-23.pkl'
+    model_path = f'./static/models/{subj_name}/riemann_rest+flex_11-21-2023-21-23-15.pkl'
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']