|
@@ -9,7 +9,7 @@ import yaml
|
|
import os
|
|
import os
|
|
import argparse
|
|
import argparse
|
|
import logging
|
|
import logging
|
|
-from scipy import stats
|
|
|
|
|
|
+from sklearn.metrics import accuracy_score
|
|
from dataloaders import neo
|
|
from dataloaders import neo
|
|
import bci_core.online as online
|
|
import bci_core.online as online
|
|
import bci_core.utils as bci_utils
|
|
import bci_core.utils as bci_utils
|
|
@@ -71,8 +71,55 @@ class DataGenerator:
|
|
yield i / self.fs, self.get_data_batch(i)
|
|
yield i / self.fs, self.get_data_batch(i)
|
|
|
|
|
|
|
|
|
|
|
|
+def _evaluation_loop(raw, events, model_hmm, step_length):
|
|
|
|
+ val_data = raw.get_data()
|
|
|
|
+ fs = raw.info['sfreq']
|
|
|
|
+
|
|
|
|
+ data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
|
|
|
|
+
|
|
|
|
+ decision_with_hmm = []
|
|
|
|
+ decision_without_hmm = []
|
|
|
|
+ probs = []
|
|
|
|
+ for time, data in data_gen.loop():
|
|
|
|
+ step_p, cls = model_hmm.viterbi(data, return_step_p=True)
|
|
|
|
+ if cls >=0:
|
|
|
|
+ cls = model_hmm.model.classes_[cls]
|
|
|
|
+ decision_with_hmm.append((time, cls)) # map to unified label
|
|
|
|
+ decision_without_hmm.append((time, model_hmm.model.classes_[np.argmax(step_p)]))
|
|
|
|
+ probs.append((time, model_hmm.probability))
|
|
|
|
+ probs = np.array(probs)
|
|
|
|
+
|
|
|
|
+ events_pred = _construct_model_event(decision_with_hmm, fs)
|
|
|
|
+ events_pred_naive = _construct_model_event(decision_without_hmm, fs)
|
|
|
|
+
|
|
|
|
+ p_hmm, r_hmm, f1_hmm = bci_utils.event_metric(event_true=events, event_pred=events_pred, fs=fs)
|
|
|
|
+
|
|
|
|
+ p_n, r_n, f1_n = bci_utils.event_metric(events, events_pred_naive, fs=fs)
|
|
|
|
+
|
|
|
|
+ stim_true = _event_to_stim_channel(events, len(raw.times))
|
|
|
|
+ stim_pred = _event_to_stim_channel(events_pred, len(raw.times))
|
|
|
|
+ stim_pred_naive = _event_to_stim_channel(events_pred_naive, len(raw.times))
|
|
|
|
+
|
|
|
|
+ # TODO: auc
|
|
|
|
+ accu_hmm = accuracy_score(stim_true, stim_pred)
|
|
|
|
+ accu_naive = accuracy_score(stim_true, stim_pred_naive)
|
|
|
|
+
|
|
|
|
+ fig_pred, ax = plt.subplots(4, 1, sharex=True, sharey=False)
|
|
|
|
+ ax[0].set_title('pred (naive)')
|
|
|
|
+ ax[0].plot(raw.times, stim_pred_naive)
|
|
|
|
+ ax[1].set_title('pred')
|
|
|
|
+ ax[1].plot(raw.times, stim_pred)
|
|
|
|
+ ax[2].set_title('true')
|
|
|
|
+ ax[2].plot(raw.times, stim_true)
|
|
|
|
+ ax[3].set_title('prob')
|
|
|
|
+ ax[3].plot(probs[:, 0], probs[:, 1])
|
|
|
|
+ ax[3].set_ylim([0, 1])
|
|
|
|
+
|
|
|
|
+ return fig_pred, (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)
|
|
|
|
+
|
|
|
|
+
|
|
def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
|
|
def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
|
|
- """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
|
|
|
|
|
|
+ """模型验证接口,使用指定数据进行验证,绘制ersd map
|
|
Args:
|
|
Args:
|
|
raw (mne.io.Raw)
|
|
raw (mne.io.Raw)
|
|
event_id (dict)
|
|
event_id (dict)
|
|
@@ -99,38 +146,12 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length
|
|
|
|
|
|
|
|
|
|
controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
|
|
controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
|
|
|
|
+ model_hmm = controller.real_feedback_model
|
|
|
|
|
|
- # validate with the second half
|
|
|
|
- val_data = raw_val.get_data()
|
|
|
|
- data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
|
|
|
|
-
|
|
|
|
- decisions = []
|
|
|
|
- probs = []
|
|
|
|
- for time, data in data_gen.loop():
|
|
|
|
- cls = controller.decision(data)
|
|
|
|
- decisions.append((time, cls))
|
|
|
|
- probs.append((time, controller.real_feedback_model.probability))
|
|
|
|
- probs = np.array(probs)
|
|
|
|
-
|
|
|
|
- events_pred = _construct_model_event(decisions, fs)
|
|
|
|
-
|
|
|
|
- precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
|
|
|
|
-
|
|
|
|
- stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
|
|
|
|
- stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
|
|
|
|
-
|
|
|
|
- corr, _ = stats.pearsonr(stim_pred, stim_true)
|
|
|
|
-
|
|
|
|
- fig_pred, ax = plt.subplots(3, 1, sharex=True, sharey=False)
|
|
|
|
- ax[0].set_title('pred')
|
|
|
|
- ax[0].plot(raw_val.times, stim_pred)
|
|
|
|
- ax[1].set_title('true')
|
|
|
|
- ax[1].plot(raw_val.times, stim_true)
|
|
|
|
- ax[2].set_title('prob')
|
|
|
|
- ax[2].plot(probs[:, 0], probs[:, 1])
|
|
|
|
- ax[2].set_ylim([0, 1])
|
|
|
|
|
|
+ # run with and without hmm
|
|
|
|
+ fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, events_val, model_hmm, step_length)
|
|
|
|
|
|
- return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
|
|
|
|
|
|
+ return metric_hmm, metric_naive, fig_erds, fig_pred
|
|
|
|
|
|
|
|
|
|
def _construct_model_event(decision_seq, fs):
|
|
def _construct_model_event(decision_seq, fs):
|
|
@@ -164,14 +185,16 @@ if __name__ == '__main__':
|
|
event_id[f] = neo.FINGERMODEL_IDS[f]
|
|
event_id[f] = neo.FINGERMODEL_IDS[f]
|
|
|
|
|
|
# preprocess raw
|
|
# preprocess raw
|
|
- raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
|
|
|
|
|
|
+ trial_duration = config_info['buffer_length']
|
|
|
|
+ raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1], upsampled_epoch_length=trial_duration)
|
|
|
|
|
|
# do validations
|
|
# do validations
|
|
- metrics, fig_erds, fig_pred = validation(raw,
|
|
|
|
|
|
+ metric_hmm, metric_naive, fig_erds, fig_pred = validation(raw,
|
|
event_id,
|
|
event_id,
|
|
model=model_path,
|
|
model=model_path,
|
|
state_change_threshold=args.state_change_threshold,
|
|
state_change_threshold=args.state_change_threshold,
|
|
step_length=config_info['buffer_length'])
|
|
step_length=config_info['buffer_length'])
|
|
fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
|
|
fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
|
|
fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
|
|
fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
|
|
- logger.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')
|
|
|
|
|
|
+ logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
|
|
|
|
+ logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')
|