@@ -13,6 +13,7 @@ from sklearn.metrics import accuracy_score
 from dataloaders import neo
 
 import bci_core.online as online
 import bci_core.utils as bci_utils
+import bci_core.viz as bci_viz
 
 from settings.config import settings
@@ -87,14 +88,18 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
     decision_with_hmm = []
     decision_without_hmm = []
     probs = []
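+    # raw per-step classifier probabilities, collected for the naive (no-HMM) baseline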
+    probs_naive = []
     for time, data in data_gen.loop(step_length):
         step_p, cls = model_hmm.viterbi(data, return_step_p=True)
         if cls >=0:
             cls = model_hmm.model.classes_[cls]
         decision_with_hmm.append((time, cls))  # map to unified label
         decision_without_hmm.append((time, model_hmm.model.classes_[np.argmax(step_p)]))
-        probs.append((time, model_hmm.probability))
+        probs.append(model_hmm.probability)
+        # TODO: handle the multiclass case (step_p[1] assumes a binary classifier)
+        probs_naive.append(step_p[1])
     probs = np.array(probs)
+    probs_naive = np.array(probs_naive)
 
     events_pred = _construct_model_event(decision_with_hmm, fs)
     events_pred_naive = _construct_model_event(decision_without_hmm, fs)
@@ -110,16 +115,19 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
     accu_hmm = accuracy_score(stim_true, stim_pred)
     accu_naive = accuracy_score(stim_true, stim_pred_naive)
 
-    fig_pred, ax = plt.subplots(4, 1, sharex=True, sharey=False)
-    ax[0].set_title('pred (naive)')
-    ax[0].plot(raw.times, stim_pred_naive)
-    ax[1].set_title('pred')
-    ax[1].plot(raw.times, stim_pred)
-    ax[2].set_title('true')
-    ax[2].plot(raw.times, stim_true)
-    ax[3].set_title('prob')
-    ax[3].plot(probs[:, 0], probs[:, 1])
-    ax[3].set_ylim([0, 1])
+    # TODO: handle the multiclass case (one-hot)
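+    # five stacked panels: true states, HMM vs. naive probabilities, then HMM vs. naive label sequences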
+    fig_pred, axes = plt.subplots(5, 1, sharex=True, figsize=(10, 8))
+    axes[0].set_title('True states')
+    axes[0].plot(raw.times, stim_true)
+    axes[0].set_axis_off()
+    axes[1].set_title('With HMM (probs)')
+    bci_viz.plot_state_seq_with_cue((raw.times[0], raw.times[-1]), stim_true, probs, ax=axes[1])
+    axes[2].set_title('Without HMM (probs)')
+    bci_viz.plot_state_seq_with_cue((raw.times[0], raw.times[-1]), stim_true, probs_naive, ax=axes[2])
+    axes[3].set_title('With HMM')
+    bci_viz.plot_state_seq_with_cue((raw.times[0], raw.times[-1]), stim_true, stim_pred, ax=axes[3])
+    axes[4].set_title('Without HMM')
+    bci_viz.plot_state_seq_with_cue((raw.times[0], raw.times[-1]), stim_true, stim_pred_naive, ax=axes[4])
 
     return fig_pred, (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)
 
@@ -155,12 +163,24 @@ def simulation(raw_val, event_id, model,
     return metric_hmm, metric_naive, fig_pred
 
 
-def _construct_model_event(decision_seq, fs):
-    events = []
-    for i in decision_seq:
-        time, cls = i
-        if cls >= 0:
-            events.append([int(time * fs), 0, cls])
+def _construct_model_event(decision_seq, fs, start_cond=0):
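+    """Convert a (time, label) decision sequence into an event array.
+
+    ``start_cond`` is the state assumed before the first decision; steps
+    labelled -1 (no decision) inherit the previous state, and a row of
+    [sample, 0, label] is emitted only when the decoded state changes.
+    """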
+    def _filter_seq(decision_seq):
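+        # forward-fill: undecided (-1) steps keep the previous state's label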
+        new_seq = [(decision_seq[0][0], start_cond)]
+        for i in range(1, len(decision_seq)):
+            if decision_seq[i][1] == -1:
+                new_seq.append((decision_seq[i][0], new_seq[-1][1]))
+            else:
+                new_seq.append(decision_seq[i])
+        return new_seq
+    decision_seq = _filter_seq(decision_seq)
+
+    last_state = decision_seq[0][1]
+    events = [[int(decision_seq[0][0] * fs), 0, last_state]]
+    for i in range(1, len(decision_seq)):
+        time, label = decision_seq[i]
+        if label != last_state:
+            last_state = label
+            events.append([int(time * fs), 0, label])
     return np.array(events)
 
 
@@ -209,3 +229,5 @@ if __name__ == '__main__':
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
     logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
     logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')
+
+    plt.show()