|
@@ -71,7 +71,7 @@ class DataGenerator:
|
|
yield i / self.fs, self.get_data_batch(i)
|
|
yield i / self.fs, self.get_data_batch(i)
|
|
|
|
|
|
|
|
|
|
-def _evaluation_loop(raw, events, model_hmm, step_length):
|
|
|
|
|
|
+def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
|
|
val_data = raw.get_data()
|
|
val_data = raw.get_data()
|
|
fs = raw.info['sfreq']
|
|
fs = raw.info['sfreq']
|
|
|
|
|
|
@@ -96,7 +96,7 @@ def _evaluation_loop(raw, events, model_hmm, step_length):
|
|
|
|
|
|
p_n, r_n, f1_n = bci_utils.event_metric(events, events_pred_naive, fs=fs)
|
|
p_n, r_n, f1_n = bci_utils.event_metric(events, events_pred_naive, fs=fs)
|
|
|
|
|
|
- stim_true = _event_to_stim_channel(events, len(raw.times))
|
|
|
|
|
|
+ stim_true = _event_to_stim_channel(events, len(raw.times), trial_length=int(event_trial_length * fs))
|
|
stim_pred = _event_to_stim_channel(events_pred, len(raw.times))
|
|
stim_pred = _event_to_stim_channel(events_pred, len(raw.times))
|
|
stim_pred_naive = _event_to_stim_channel(events_pred_naive, len(raw.times))
|
|
stim_pred_naive = _event_to_stim_channel(events_pred_naive, len(raw.times))
|
|
|
|
|
|
@@ -118,7 +118,10 @@ def _evaluation_loop(raw, events, model_hmm, step_length):
|
|
return fig_pred, (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)
|
|
return fig_pred, (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)
|
|
|
|
|
|
|
|
|
|
-def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
|
|
|
|
|
|
+def validation(raw_val, event_id, model,
|
|
|
|
+ state_change_threshold=0.8,
|
|
|
|
+ step_length=1.,
|
|
|
|
+ event_trial_length=5.):
|
|
"""模型验证接口,使用指定数据进行验证,绘制ersd map
|
|
"""模型验证接口,使用指定数据进行验证,绘制ersd map
|
|
Args:
|
|
Args:
|
|
raw (mne.io.Raw)
|
|
raw (mne.io.Raw)
|
|
@@ -126,6 +129,7 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length
|
|
model: validate existing model,
|
|
model: validate existing model,
|
|
state_change_threshold (float): default 0.8
|
|
state_change_threshold (float): default 0.8
|
|
step_length (float): batch data step length, default 1. (s)
|
|
step_length (float): batch data step length, default 1. (s)
|
|
|
|
+ event_trial_length (float):
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
None
|
|
None
|
|
@@ -149,7 +153,11 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length
|
|
model_hmm = controller.real_feedback_model
|
|
model_hmm = controller.real_feedback_model
|
|
|
|
|
|
# run with and without hmm
|
|
# run with and without hmm
|
|
- fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, events_val, model_hmm, step_length)
|
|
|
|
|
|
+ fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val,
|
|
|
|
+ events_val,
|
|
|
|
+ model_hmm,
|
|
|
|
+ step_length,
|
|
|
|
+ event_trial_length=event_trial_length)
|
|
|
|
|
|
return metric_hmm, metric_naive, fig_erds, fig_pred
|
|
return metric_hmm, metric_naive, fig_erds, fig_pred
|
|
|
|
|
|
@@ -163,10 +171,13 @@ def _construct_model_event(decision_seq, fs):
|
|
return np.array(events)
|
|
return np.array(events)
|
|
|
|
|
|
|
|
|
|
-def _event_to_stim_channel(events, time_length):
|
|
|
|
|
|
+def _event_to_stim_channel(events, time_length, trial_length=None):
|
|
x = np.zeros(time_length)
|
|
x = np.zeros(time_length)
|
|
for i in range(0, len(events) - 1):
|
|
for i in range(0, len(events) - 1):
|
|
- x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
|
|
|
|
|
|
+ if trial_length is not None:
|
|
|
|
+ x[events[i, 0]: events[i, 0] + trial_length] = events[i, 2]
|
|
|
|
+ else:
|
|
|
|
+ x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
|
|
return x
|
|
return x
|
|
|
|
|
|
|
|
|
|
@@ -185,15 +196,21 @@ if __name__ == '__main__':
|
|
event_id[f] = neo.FINGERMODEL_IDS[f]
|
|
event_id[f] = neo.FINGERMODEL_IDS[f]
|
|
|
|
|
|
# preprocess raw
|
|
# preprocess raw
|
|
- trial_duration = config_info['buffer_length']
|
|
|
|
- raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1], upsampled_epoch_length=trial_duration)
|
|
|
|
|
|
+ trial_time = 5.
|
|
|
|
+ raw = neo.raw_preprocessing(data_dir, sessions,
|
|
|
|
+ unify_label=True,
|
|
|
|
+ ori_epoch_length=trial_time,
|
|
|
|
+ mov_trial_ind=[2],
|
|
|
|
+ rest_trial_ind=[1],
|
|
|
|
+ upsampled_epoch_length=None)
|
|
|
|
|
|
# do validations
|
|
# do validations
|
|
metric_hmm, metric_naive, fig_erds, fig_pred = validation(raw,
|
|
metric_hmm, metric_naive, fig_erds, fig_pred = validation(raw,
|
|
event_id,
|
|
event_id,
|
|
model=model_path,
|
|
model=model_path,
|
|
state_change_threshold=args.state_change_threshold,
|
|
state_change_threshold=args.state_change_threshold,
|
|
- step_length=config_info['buffer_length'])
|
|
|
|
|
|
+ step_length=config_info['buffer_length'],
|
|
|
|
+ event_trial_length=trial_time)
|
|
fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
|
|
fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
|
|
fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
|
|
fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
|
|
logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
|
|
logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
|