Bläddra i källkod

Improve visulization

dk 1 år sedan
förälder
incheckning
2703e28116
1 ändrade filer med 3 tillägg och 5 borttagningar
  1. 3 5
      backend/online_sim.py

+ 3 - 5
backend/online_sim.py

@@ -118,8 +118,7 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
     # hmm
     fig_hmm, axes = plt.subplots(model_hmm.n_classes + 2, 1, sharex=True, figsize=(10, 8))
     axes[0].set_title('True states')
-    axes[0].plot(raw.times, stim_true)
-    axes[0].set_axis_off()
+    bci_viz.plot_states((raw.times[0], raw.times[-1]), stim_true, ax=axes[0])
     axes[1].set_title('State sequence')
     bci_viz.plot_states((raw.times[0], raw.times[-1]), stim_pred, ax=axes[1])
     for i, ax in enumerate(axes[2:]):
@@ -129,8 +128,7 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
     # without hmm
     fig_naive, axes = plt.subplots(model_hmm.n_classes + 2, 1, sharex=True, figsize=(10, 8))
     axes[0].set_title('True states')
-    axes[0].plot(raw.times, stim_true)
-    axes[0].set_axis_off()
+    bci_viz.plot_states((raw.times[0], raw.times[-1]), stim_true, ax=axes[0])
     axes[1].set_title('State sequence')
     bci_viz.plot_states((raw.times[0], raw.times[-1]), stim_pred_naive, ax=axes[1])
     for i, ax in enumerate(axes[2:]):
@@ -227,7 +225,7 @@ if __name__ == '__main__':
     }
     model_hmm = online.model_loader(model_path, **input_kwargs)
 
-    # do validations
+    # do online simulation
     metric_hmm, metric_naive, fig_pred = simulation(raw, 
                                              event_id, 
                                              model=model_hmm,