|
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
|
|
|
import mne
|
|
|
import yaml
|
|
|
import os
|
|
|
-import joblib
|
|
|
+import logging
|
|
|
from scipy import stats
|
|
|
from dataloaders import neo
|
|
|
import bci_core.online as online
|
|
@@ -15,6 +15,9 @@ import bci_core.utils as bci_utils
|
|
|
import bci_core.viz as bci_viz
|
|
|
|
|
|
|
|
|
+logging.basicConfig(level=logging.INFO)
|
|
|
+
|
|
|
+
|
|
|
class DataGenerator:
|
|
|
def __init__(self, fs, X):
|
|
|
self.fs = int(fs)
|
|
@@ -32,11 +35,10 @@ class DataGenerator:
|
|
|
yield i / self.fs, self.get_data_batch(i)
|
|
|
|
|
|
|
|
|
-def validation(raw_val, model_type, event_id, model, state_change_threshold=0.8):
|
|
|
+def validation(raw_val, event_id, model, state_change_threshold=0.8):
|
|
|
"""模型验证接口,使用指定数据进行训练+验证,绘制ersd map
|
|
|
Args:
|
|
|
raw (mne.io.Raw)
|
|
|
- model_type (string): type of model to train, baseline or riemann
|
|
|
event_id (dict)
|
|
|
model: validate existing model,
|
|
|
state_change_threshold (float): default 0.8
|
|
@@ -58,12 +60,8 @@ def validation(raw_val, model_type, event_id, model, state_change_threshold=0.8)
|
|
|
mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
|
|
|
use_original_label=True)
|
|
|
|
|
|
- if model_type == 'baseline':
|
|
|
- hmm_model = online.BaselineHMM(model, state_change_threshold=state_change_threshold)
|
|
|
- else:
|
|
|
- raise NotImplementedError
|
|
|
- controller = online.Controller(0, None)
|
|
|
- controller.set_real_feedback_model(hmm_model)
|
|
|
+
|
|
|
+ controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
|
|
|
|
|
|
# validate with the second half
|
|
|
val_data = raw_val.get_data()
|
|
@@ -107,27 +105,24 @@ if __name__ == '__main__':
|
|
|
model_type = 'baseline'
|
|
|
# TODO: load subject config
|
|
|
|
|
|
- data_dir = f'./data/{subj_name}/val/'
|
|
|
- model_path = f'./static/models/{subj_name}/scis.pkl'
|
|
|
+ data_dir = f'./data/{subj_name}/train/'
|
|
|
+ model_path = f'./static/models/{subj_name}/baseline_rest+cylinder_11-15-2023-21-34-41.pkl'
|
|
|
|
|
|
- info = yaml.safe_load(os.path.join(data_dir, 'info.yml'))
|
|
|
+ with open(os.path.join(data_dir, 'info.yml'), 'r') as f:
|
|
|
+ info = yaml.safe_load(f)
|
|
|
sessions = info['sessions']
|
|
|
event_id = {'rest': 0}
|
|
|
for f in sessions.keys():
|
|
|
event_id[f] = neo.FINGERMODEL_IDS[f]
|
|
|
|
|
|
# preprocess raw
|
|
|
- raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False, ori_epoch_length=5)
|
|
|
-
|
|
|
- # load model
|
|
|
- model = joblib.load(model_path)
|
|
|
- model_type, events = bci_utils.parse_model_type(model_path)
|
|
|
+ raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[6], rest_trial_ind=[7])
|
|
|
|
|
|
+ # do validations
|
|
|
metrics, fig_erds, fig_pred = validation(raw,
|
|
|
- model_type,
|
|
|
event_id,
|
|
|
- model=model,
|
|
|
- state_change_threshold=0.8)
|
|
|
+ model=model_path,
|
|
|
+ state_change_threshold=0.95)
|
|
|
fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
|
|
|
fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
|
|
|
- print(metrics)
|
|
|
+ logging.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')
|