|
@@ -118,8 +118,7 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
|
|
|
# hmm
|
|
|
fig_hmm, axes = plt.subplots(model_hmm.n_classes + 2, 1, sharex=True, figsize=(10, 8))
|
|
|
axes[0].set_title('True states')
|
|
|
- axes[0].plot(raw.times, stim_true)
|
|
|
- axes[0].set_axis_off()
|
|
|
+ bci_viz.plot_states((raw.times[0], raw.times[-1]), stim_true, ax=axes[0])
|
|
|
axes[1].set_title('State sequence')
|
|
|
bci_viz.plot_states((raw.times[0], raw.times[-1]), stim_pred, ax=axes[1])
|
|
|
for i, ax in enumerate(axes[2:]):
|
|
@@ -129,8 +128,7 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
|
|
|
# without hmm
|
|
|
fig_naive, axes = plt.subplots(model_hmm.n_classes + 2, 1, sharex=True, figsize=(10, 8))
|
|
|
axes[0].set_title('True states')
|
|
|
- axes[0].plot(raw.times, stim_true)
|
|
|
- axes[0].set_axis_off()
|
|
|
+ bci_viz.plot_states((raw.times[0], raw.times[-1]), stim_true, ax=axes[0])
|
|
|
axes[1].set_title('State sequence')
|
|
|
bci_viz.plot_states((raw.times[0], raw.times[-1]), stim_pred_naive, ax=axes[1])
|
|
|
for i, ax in enumerate(axes[2:]):
|
|
@@ -227,7 +225,7 @@ if __name__ == '__main__':
|
|
|
}
|
|
|
model_hmm = online.model_loader(model_path, **input_kwargs)
|
|
|
|
|
|
- # do validations
|
|
|
+ # do online simulation
|
|
|
metric_hmm, metric_naive, fig_pred = simulation(raw,
|
|
|
event_id,
|
|
|
model=model_hmm,
|