'''
Simulated online test script for the model.
The data is split into two folds: one fold trains the model, the other is
tested in online mode: decision AUC + event f1-score
'''
import numpy as np
import matplotlib.pyplot as plt
import mne
import yaml
import os
import joblib
from scipy import stats

from dataloaders import neo
import bci_core.online as online
import bci_core.utils as bci_utils
import bci_core.viz as bci_viz


class DataGenerator:
    def __init__(self, fs, X):
        self.fs = int(fs)
        self.X = X

    def get_data_batch(self, current_index):
        # return the 1 s batch (fs samples) that ends at current_index
        data = self.X[:, current_index - self.fs:current_index].copy()
        return self.fs, [], data

    def loop(self, step_size=0.1):
        # slide the 1 s window forward by step_size seconds per iteration
        step = int(step_size * self.fs)
        for i in range(self.fs, self.X.shape[1] + 1, step):
            yield i / self.fs, self.get_data_batch(i)


def validation(raw_val, model_type, event_id, model, state_change_threshold=0.8):
    """Model validation interface: validate an existing model on the given
    data and plot the ERSD map.

    Args:
        raw_val (mne.io.Raw): validation recording
        model_type (string): type of model to validate; only 'baseline' is
            implemented here
        event_id (dict): mapping from event name to event code
        model: existing model to validate
        state_change_threshold (float): default 0.8

    Returns:
        tuple: (precision, recall, f_beta_score, corr), fig_erds, fig_pred
    """
    fs = raw_val.info['sfreq']

    events_val, _ = mne.events_from_annotations(raw_val, event_id)

    # plot ersd map
    fig_erds = bci_viz.plot_ersd(raw_val.get_data(), events_val, fs, (0, 1), event_id, 0)

    events_val = neo.reconstruct_events(events_val,
                                        fs,
                                        finger_model=None,
                                        rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
                                        mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
                                        use_original_label=True)

    if model_type == 'baseline':
        hmm_model = online.BaselineHMM(model, state_change_threshold=state_change_threshold)
    else:
        raise NotImplementedError

    controller = online.Controller(0, None)
    controller.set_real_feedback_model(hmm_model)

    # validate with the second half
    val_data = raw_val.get_data()
    data_gen = DataGenerator(fs, val_data)
    rets = []
    for time, data in data_gen.loop():
        cls = controller.decision(data)
        rets.append((time, cls))

    events_pred = _construct_model_event(rets, fs)

    precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)

    stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
    stim_true = _event_to_stim_channel(events_val, len(raw_val.times))

    # the p-value is not used
    corr, _ = stats.pearsonr(stim_pred, stim_true)

    fig_pred, ax = plt.subplots(1, 1)
    ax.plot(raw_val.times, stim_pred, label='pred')
    ax.plot(raw_val.times, stim_true, label='true')
    ax.legend()

    return (precision, recall, f_beta_score, corr), fig_erds, fig_pred


def _construct_model_event(decision_seq, fs):
    # keep only valid decisions (cls >= 0) as rows of [sample index, 0, class]
    events = []
    for time, cls in decision_seq:
        if cls >= 0:
            events.append([int(time * fs), 0, cls])
    return np.array(events)


def _event_to_stim_channel(events, time_length):
    # hold each event's label until just before the next event starts;
    # samples after the last event stay 0
    x = np.zeros(time_length)
    for i in range(0, len(events) - 1):
        x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
    return x


if __name__ == '__main__':
    # TODO: argparse
    subj_name = 'ylj'
    model_type = 'baseline'
    # TODO: load subject config

    data_dir = f'./data/{subj_name}/val/'
    model_path = f'./static/models/{subj_name}/scis.pkl'

    # yaml.safe_load parses a stream or string, so open the file first;
    # passing the path itself would just return the path as a string
    with open(os.path.join(data_dir, 'info.yml'), 'r') as info_file:
        info = yaml.safe_load(info_file)
    sessions = info['sessions']
    event_id = {'rest': 0}
    for f in sessions.keys():
        event_id[f] = neo.FINGERMODEL_IDS[f]

    # preprocess raw
    raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False, ori_epoch_length=5)

    # load model
    model = joblib.load(model_path)
    model_type, events = bci_utils.parse_model_type(model_path)

    metrics, fig_erds, fig_pred = validation(raw,
                                             model_type,
                                             event_id,
                                             model=model,
                                             state_change_threshold=0.8)
    fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
    fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
    print(metrics)