|
@@ -16,7 +16,7 @@ import bci_core.viz as bci_viz
|
|
|
from settings.config import settings
|
|
|
|
|
|
|
|
|
-logging.basicConfig(level=logging.INFO)
|
|
|
+logging.basicConfig(level=logging.DEBUG)
|
|
|
config_info = settings.CONFIG_INFO
|
|
|
|
|
|
|
|
@@ -71,19 +71,32 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length
|
|
|
# validate with the second half
|
|
|
val_data = raw_val.get_data()
|
|
|
data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
|
|
|
- rets = []
|
|
|
+
|
|
|
+ decisions = []
|
|
|
+ probs = []
|
|
|
for time, data in data_gen.loop():
|
|
|
cls = controller.decision(data)
|
|
|
- rets.append((time, cls))
|
|
|
- events_pred = _construct_model_event(rets, fs)
|
|
|
+ decisions.append((time, cls))
|
|
|
+ probs.append((time, controller.real_feedback_model.probability))
|
|
|
+ probs = np.array(probs)
|
|
|
+
|
|
|
+ events_pred = _construct_model_event(decisions, fs)
|
|
|
+
|
|
|
precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
|
|
|
+
|
|
|
stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
|
|
|
stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
|
|
|
- corr, p = stats.pearsonr(stim_pred, stim_true)
|
|
|
- fig_pred, ax = plt.subplots(1, 1)
|
|
|
- ax.plot(raw_val.times, stim_pred, label='pred')
|
|
|
- ax.plot(raw_val.times, stim_true, label='true')
|
|
|
- ax.legend()
|
|
|
+
|
|
|
+ corr, _ = stats.pearsonr(stim_pred, stim_true)
|
|
|
+
|
|
|
+ fig_pred, ax = plt.subplots(1, 1, sharex=True, sharey=False)
|
|
|
+ ax[0].set_title('pred')
|
|
|
+ ax[0].plot(raw_val.times, stim_pred)
|
|
|
+ ax[1].set_title('true')
|
|
|
+ ax[1].plot(raw_val.times, stim_true)
|
|
|
+ ax[2].set_title('prob')
|
|
|
+ ax[2].plot(probs[:, 0], probs[:, 1])
|
|
|
+ ax[2].set_ylim([0, 1])
|
|
|
|
|
|
return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
|
|
|
|
|
@@ -110,7 +123,7 @@ if __name__ == '__main__':
|
|
|
# TODO: load subject config
|
|
|
|
|
|
data_dir = f'./data/{subj_name}/'
|
|
|
- model_path = f'./static/models/{subj_name}/riemann_rest+flex_11-21-2023-16-43-23.pkl'
|
|
|
+ model_path = f'./static/models/{subj_name}/riemann_rest+flex_11-21-2023-21-23-15.pkl'
|
|
|
with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
|
|
|
info = yaml.safe_load(f)
|
|
|
sessions = info['sessions']
|