'''
Simulated online model testing script.
The data are split into two folds: one fold trains the model, the other fold is
evaluated in online mode: decision AUC + event f1-score.
'''
import numpy as np
import matplotlib.pyplot as plt
import mne
import yaml
import os
import joblib
from scipy import stats

from dataloaders import neo
import bci_core.online as online
import bci_core.utils as bci_utils
import bci_core.viz as bci_viz


class DataGenerator:
    """Replays a continuous recording in fixed-size chunks to mimic the online data stream."""

    def __init__(self, fs, X):
        self.fs = int(fs)
        self.X = X

    def get_data_batch(self, current_index):
        # return the most recent 1 s of data ending at current_index
        data = self.X[:, current_index - self.fs:current_index].copy()
        return self.fs, [], data

    def loop(self, step_size=0.1):
        # yield (time in seconds, data batch) every step_size seconds
        step = int(step_size * self.fs)
        for i in range(self.fs, self.X.shape[1] + 1, step):
            yield i / self.fs, self.get_data_batch(i)
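

# Illustrative use of DataGenerator (a sketch, not executed by this script):
# each iteration yields the latest 1 s window, advanced in 0.1 s steps, which is
# what the online controller consumes in validation() below.
#
#   gen = DataGenerator(fs=1000, X=raw_val.get_data())
#   for t, (fs_batch, event, chunk) in gen.loop(step_size=0.1):
#       # chunk has shape (n_channels, fs_batch) and ends at time t seconds
#       pass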


def validation(raw_val, model_type, event_id, model, state_change_threshold=0.8):
    """Model validation interface: evaluate an existing model on the given data in
    simulated online mode and plot the ERSD map.

    Args:
        raw_val (mne.io.Raw): validation data
        model_type (string): type of model to validate, 'baseline' or 'riemann'
        event_id (dict): mapping from event names to event codes
        model: existing model to validate
        state_change_threshold (float): HMM state-change threshold, default 0.8

    Returns:
        tuple: (precision, recall, f_beta_score, corr), the ERSD figure and the prediction figure
    """
    fs = raw_val.info['sfreq']
    events_val, _ = mne.events_from_annotations(raw_val, event_id)
    # plot ERSD map
    fig_erds = bci_viz.plot_ersd(raw_val.get_data(), events_val, fs, (0, 1), event_id, 0)

    events_val = neo.reconstruct_events(events_val,
                                        fs,
                                        finger_model=None,
                                        rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
                                        mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
                                        use_original_label=True)

    if model_type == 'baseline':
        hmm_model = online.BaselineHMM(model, state_change_threshold=state_change_threshold)
    else:
        raise NotImplementedError

    controller = online.Controller(0, None)
    controller.set_real_feedback_model(hmm_model)

    # replay the validation data and collect the controller's decisions
    val_data = raw_val.get_data()
    data_gen = DataGenerator(fs, val_data)
    rets = []
    for time, data in data_gen.loop():
        cls = controller.decision(data)
        rets.append((time, cls))
    events_pred = _construct_model_event(rets, fs)

    # event-level metrics and sample-level correlation between predicted and true states
    precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
    stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
    stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
    corr, p = stats.pearsonr(stim_pred, stim_true)

    fig_pred, ax = plt.subplots(1, 1)
    ax.plot(raw_val.times, stim_pred, label='pred')
    ax.plot(raw_val.times, stim_true, label='true')
    ax.legend()
    return (precision, recall, f_beta_score, corr), fig_erds, fig_pred


def _construct_model_event(decision_seq, fs):
    # convert (time, class) decisions into an MNE-style event array; class < 0 means "no decision"
    events = []
    for time, cls in decision_seq:
        if cls >= 0:
            events.append([int(time * fs), 0, cls])
    return np.array(events)


def _event_to_stim_channel(events, time_length):
    # expand an event array into a sample-wise label vector, holding each label
    # until just before the next event onset
    x = np.zeros(time_length)
    for i in range(0, len(events) - 1):
        x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
    return x
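

# Worked example for the two helpers above (hypothetical numbers): with fs = 10,
# decisions [(1.0, 2), (2.0, -1), (3.0, 0)] become events [[10, 0, 2], [30, 0, 0]]
# (the -1 "no decision" entry is dropped), and _event_to_stim_channel(..., time_length=40)
# holds label 2 on samples 10..28; the final event's label is not written because
# the loop stops at the second-to-last event.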


if __name__ == '__main__':
    # TODO: argparse
    subj_name = 'ylj'
    model_type = 'baseline'
    # TODO: load subject config
    data_dir = f'./data/{subj_name}/val/'
    model_path = f'./static/models/{subj_name}/scis.pkl'

    with open(os.path.join(data_dir, 'info.yml'), 'r') as config_file:
        info = yaml.safe_load(config_file)
    sessions = info['sessions']
    event_id = {'rest': 0}
    for f in sessions.keys():
        event_id[f] = neo.FINGERMODEL_IDS[f]

    # preprocess raw
    raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False, ori_epoch_length=5)

    # load model (the model type parsed from the file name overrides the default above)
    model = joblib.load(model_path)
    model_type, events = bci_utils.parse_model_type(model_path)

    metrics, fig_erds, fig_pred = validation(raw,
                                             model_type,
                                             event_id,
                                             model=model,
                                             state_change_threshold=0.8)
    fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
    fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
    print(metrics)