'''
Offline simulation of the online model test.

Online-mode evaluation: event-level F1 score and decision trace.
'''
import numpy as np
import matplotlib.pyplot as plt
import mne
import yaml
import os
import logging
from scipy import stats

from dataloaders import neo
import bci_core.online as online
import bci_core.utils as bci_utils
import bci_core.viz as bci_viz
from settings.config import settings


logging.basicConfig(level=logging.INFO)

config_info = settings.CONFIG_INFO


class DataGenerator:
    """Replay a continuous recording as a stream of fixed-length data windows.

    Simulates the online acquisition loop offline. Each yielded batch holds the
    most recent ``epoch_step`` seconds of ``X`` ending at the current sample.
    ``X`` is indexed as ``X[:, start:stop]``, i.e. samples run along axis 1.
    """

    def __init__(self, fs, X, epoch_step=1.):
        """
        Args:
            fs (float): sampling rate in Hz; truncated to int.
            X (np.ndarray): 2-D data array with samples along axis 1.
            epoch_step (float): window length in seconds, default 1.
        """
        self.fs = int(fs)
        self.X = X
        self.epoch_step = epoch_step

    def _window_length(self):
        """Window length in samples."""
        return int(self.epoch_step * self.fs)

    def get_data_batch(self, current_index):
        """Return the ``epoch_step``-second window ending at ``current_index``.

        Returns:
            tuple: ``(fs, [], data)`` where ``data`` is a copy of
            ``X[:, current_index - window:current_index]``. The empty list is
            passed through as-is (presumably an event/marker slot expected by
            the online controller -- TODO confirm).
        """
        window = self._window_length()
        data = self.X[:, current_index - window:current_index].copy()
        return self.fs, [], data

    def loop(self, step_size=0.1):
        """Iterate over the recording in steps of ``step_size`` seconds.

        Yields:
            tuple: ``(time_s, batch)`` where ``time_s`` is the time (s) of the
            window's last sample and ``batch`` is ``get_data_batch(i)``.
        """
        # Never start before one full window: a negative slice start would wrap
        # around to the end of X and produce an empty/incorrect window. The
        # original 1 s start is kept when the window is shorter than that.
        start = max(self.fs, self._window_length())
        # Guard against a zero step, which would make range() raise.
        step = max(1, int(step_size * self.fs))
        for i in range(start, self.X.shape[1] + 1, step):
            yield i / self.fs, self.get_data_batch(i)


def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
    """Validate an existing model on a recording and plot its ERDS map.

    No training happens here: the recording is replayed window by window
    through ``online.Controller`` and the resulting decisions are compared
    with the annotated events.

    Args:
        raw_val (mne.io.Raw): recording to validate on; its annotations are
            converted to events with ``event_id``.
        event_id (dict): event name -> event code. The key ``'rest'`` is
            treated as the rest trial; every other key as a movement trial.
        model: model handed to ``online.Controller`` (the ``__main__`` block
            passes a path to a pickled model file).
        state_change_threshold (float): controller state-change threshold,
            default 0.8.
        step_length (float): length in seconds of each data window fed to the
            controller, default 1.

    Returns:
        tuple: ``((precision, recall, f_beta_score, corr), fig_erds, fig_pred)``
        where ``corr`` is the Pearson correlation between the predicted and
        true stim traces, ``fig_erds`` is the ERDS figure and ``fig_pred`` plots
        both traces against time.
    """
    fs = raw_val.info['sfreq']
    data = raw_val.get_data()
    n_times = len(raw_val.times)
    events_val, _ = mne.events_from_annotations(raw_val, event_id)

    # plot ERDS map
    fig_erds = bci_viz.plot_ersd(data, events_val, fs, (0, 1), event_id, 0)

    events_val = neo.reconstruct_events(
        events_val, fs,
        finger_model=None,
        rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
        mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
        use_original_label=True,
    )

    controller = online.Controller(0, model, state_change_threshold=state_change_threshold)

    # Replay the whole recording through the controller (default 0.1 s steps).
    data_gen = DataGenerator(fs, data, epoch_step=step_length)
    rets = [(time, controller.decision(batch)) for time, batch in data_gen.loop()]

    events_pred = _construct_model_event(rets, fs)
    precision, recall, f_beta_score = bci_utils.event_metric(
        event_true=events_val, event_pred=events_pred, fs=fs)

    stim_pred = _event_to_stim_channel(events_pred, n_times)
    stim_true = _event_to_stim_channel(events_val, n_times)
    corr, _ = stats.pearsonr(stim_pred, stim_true)

    fig_pred, ax = plt.subplots(1, 1)
    ax.plot(raw_val.times, stim_pred, label='pred')
    ax.plot(raw_val.times, stim_true, label='true')
    ax.legend()

    return (precision, recall, f_beta_score, corr), fig_erds, fig_pred


def _construct_model_event(decision_seq, fs):
    """Convert ``(time_s, cls)`` decisions into an MNE-style event array.

    Decisions with ``cls < 0`` are dropped. Each kept decision becomes the row
    ``[sample_index, 0, cls]``. An empty input yields a ``(0, 3)`` int array
    so that ``events[:, k]`` indexing still works downstream.
    """
    events = [[int(time * fs), 0, cls] for time, cls in decision_seq if cls >= 0]
    if not events:
        return np.empty((0, 3), dtype=int)
    return np.array(events)


def _event_to_stim_channel(events, time_length):
    """Render an event array as a piecewise-constant stim trace.

    Each event's label (column 2) is held from its onset sample (column 0) up
    to, but excluding, the next event's onset. The last event is held until
    ``time_length``. Samples before the first onset are 0.
    """
    stim = np.zeros(time_length)
    if len(events) == 0:
        return stim
    onsets = events[:, 0]
    labels = events[:, 2]
    # Slice ends are exclusive, so the next onset itself is the correct stop.
    ends = np.append(onsets[1:], time_length)
    for start, end, label in zip(onsets, ends, labels):
        stim[start:end] = label
    return stim


if __name__ == '__main__':
    # TODO: argparse
    subj_name = 'XW01'
    # TODO: load subject config
    data_dir = f'./data/{subj_name}/'
    model_path = f'./static/models/{subj_name}/riemann_rest+flex_11-21-2023-16-43-23.pkl'

    with open(os.path.join(data_dir, 'val_info.yml'), 'r', encoding='utf-8') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
    event_id = {'rest': 0}
    for finger in sessions:
        event_id[finger] = neo.FINGERMODEL_IDS[finger]

    # preprocess raw
    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5,
                                mov_trial_ind=[2], rest_trial_ind=[1])

    # do validations
    metrics, fig_erds, fig_pred = validation(raw, event_id, model=model_path,
                                             state_change_threshold=0.75,
                                             step_length=config_info['buffer_length'])
    fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
    fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
    logging.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')