'''
Model simulated-online test script.
Online-mode test: event f1-score and decision trace.
'''
import numpy as np
import matplotlib.pyplot as plt
import mne
import yaml
import os
import argparse
import logging

from sklearn.metrics import accuracy_score

from dataloaders import neo
import bci_core.online as online
import bci_core.utils as bci_utils
import bci_core.viz as bci_viz
from settings.config import settings


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO


def parse_args():
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str
    )
    parser.add_argument(
        '--state-change-threshold',
        '-scth',
        dest='state_change_threshold',
        help='Threshold for HMM state change',
        default=0.75,
        type=float
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        default=None,
        type=str
    )
    return parser.parse_args()


class DataGenerator:
    def __init__(self, fs, X, epoch_step=1.):
        self.fs = int(fs)
        self.X = X
        self.epoch_step = epoch_step

    def get_data_batch(self, current_index):
        # return an epoch_step-long window ending at current_index,
        # packed as (fs, events, data)
        ind = int(self.epoch_step * self.fs)
        data = self.X[:, current_index - ind:current_index].copy()
        return self.fs, [], data

    def loop(self, step_size=0.1):
        # slide through the recording in step_size (s) increments,
        # starting once one full second of data is available
        step = int(step_size * self.fs)
        for i in range(self.fs, self.X.shape[1] + 1, step):
            yield i / self.fs, self.get_data_batch(i)


def _evaluation_loop(raw, events, model_hmm, step_length):
    val_data = raw.get_data()
    fs = raw.info['sfreq']

    data_gen = DataGenerator(fs, val_data, epoch_step=step_length)

    decision_with_hmm = []
    decision_without_hmm = []
    probs = []
    for time, data in data_gen.loop():
        step_p, cls = model_hmm.viterbi(data, return_step_p=True)
        if cls >= 0:
            cls = model_hmm.model.classes_[cls]
        decision_with_hmm.append((time, cls))
        # map to unified label
        decision_without_hmm.append((time, model_hmm.model.classes_[np.argmax(step_p)]))
        probs.append((time, model_hmm.probability))
    probs = np.array(probs)

    events_pred = _construct_model_event(decision_with_hmm, fs)
    events_pred_naive = _construct_model_event(decision_without_hmm, fs)

    p_hmm, r_hmm, f1_hmm = bci_utils.event_metric(event_true=events, event_pred=events_pred, fs=fs)
    p_n, r_n, f1_n = bci_utils.event_metric(events, events_pred_naive, fs=fs)

    stim_true = _event_to_stim_channel(events, len(raw.times))
    stim_pred = _event_to_stim_channel(events_pred, len(raw.times))
    stim_pred_naive = _event_to_stim_channel(events_pred_naive, len(raw.times))

    # TODO: auc
    accu_hmm = accuracy_score(stim_true, stim_pred)
    accu_naive = accuracy_score(stim_true, stim_pred_naive)

    fig_pred, ax = plt.subplots(4, 1, sharex=True, sharey=False)
    ax[0].set_title('pred (naive)')
    ax[0].plot(raw.times, stim_pred_naive)
    ax[1].set_title('pred')
    ax[1].plot(raw.times, stim_pred)
    ax[2].set_title('true')
    ax[2].plot(raw.times, stim_true)
    ax[3].set_title('prob')
    ax[3].plot(probs[:, 0], probs[:, 1])
    ax[3].set_ylim([0, 1])

    return fig_pred, (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)


def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
    """Model validation interface: validate an existing model on the given data and plot the ERSD map.

    Args:
        raw_val (mne.io.Raw): validation data
        event_id (dict)
        model: existing model to validate
        state_change_threshold (float): threshold for HMM state change, default 0.8
        step_length (float): batch data step length (s), default 1.

    Returns:
        tuple: (metric_hmm, metric_naive, fig_erds, fig_pred)
    """
    fs = raw_val.info['sfreq']

    events_val, _ = mne.events_from_annotations(raw_val, event_id)

    # plot ersd map
    fig_erds = bci_viz.plot_ersd(raw_val.get_data(), events_val, fs, (0, 1), event_id, 0)
    events_val = neo.reconstruct_events(events_val, fs,
                                        finger_model=None,
                                        rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
                                        mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
                                        use_original_label=True)

    controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
    model_hmm = controller.real_feedback_model

    # run with and without hmm
    fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, events_val, model_hmm, step_length)

    return metric_hmm, metric_naive, fig_erds, fig_pred


def _construct_model_event(decision_seq, fs):
    events = []
    for time, cls in decision_seq:
        if cls >= 0:
            events.append([int(time * fs), 0, cls])
    return np.array(events)


def _event_to_stim_channel(events, time_length):
    # hold each event's label from its onset until just before the next event onset
    x = np.zeros(time_length)
    for i in range(0, len(events) - 1):
        x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
    return x


if __name__ == '__main__':
    args = parse_args()

    subj_name = args.subj
    data_dir = f'./data/{subj_name}/'
    model_path = f'./static/models/{subj_name}/{args.model_filename}'

    with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
    event_id = {'rest': 0}
    for finger in sessions.keys():
        event_id[finger] = neo.FINGERMODEL_IDS[finger]

    # preprocess raw
    trial_duration = config_info['buffer_length']
    raw = neo.raw_preprocessing(data_dir, sessions,
                                unify_label=True,
                                ori_epoch_length=5,
                                mov_trial_ind=[2],
                                rest_trial_ind=[1],
                                upsampled_epoch_length=trial_duration)

    # do validations
    metric_hmm, metric_naive, fig_erds, fig_pred = validation(raw,
                                                              event_id,
                                                              model=model_path,
                                                              state_change_threshold=args.state_change_threshold,
                                                              step_length=config_info['buffer_length'])
    fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
    fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))

    logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
    logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')
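
# Example invocation (a sketch: the script filename, subject name and model
# filename below are placeholders, not values from the repository). As built in
# __main__ above, the script expects ./data/<subj>/val_info.yml and
# ./static/models/<subj>/<model filename> to exist:
#
#   python validation_online.py --subj S01 --model-filename model.pkl -scth 0.75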