'''Simulated online model testing script.

Online-mode evaluation: event-level F1 score and the decision trace,
with and without HMM state smoothing.
'''
import numpy as np
import matplotlib.pyplot as plt
import mne
import yaml
import os
import argparse
import logging
from sklearn.metrics import accuracy_score
from dataloaders import neo
import bci_core.online as online
import bci_core.utils as bci_utils
from settings.config import settings

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO


def parse_args():
    """Parse command-line arguments for the validation run.

    Returns:
        argparse.Namespace: subj, state_change_threshold,
            state_trans_prob, model_filename.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str
    )
    parser.add_argument(
        '--state-change-threshold',
        '-scth',
        dest='state_change_threshold',
        help='Threshold for HMM state change',
        default=0.75,
        type=float
    )
    parser.add_argument(
        '--state-trans-prob',
        '-stp',
        dest='state_trans_prob',
        help='Transition probability for HMM state change',
        default=0.8,
        type=float
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        default=None,
        type=str
    )
    return parser.parse_args()


class DataGenerator:
    """Replays a continuous recording as a stream of fixed-length batches,
    emulating the online acquisition buffer.

    Args:
        fs (float): sampling frequency (Hz); stored as int.
        X (ndarray): continuous data, shape (n_channels, n_samples).
        epoch_step (float): batch window length in seconds.
    """

    def __init__(self, fs, X, epoch_step=1.):
        self.fs = int(fs)
        self.X = X
        self.epoch_step = epoch_step

    def get_data_batch(self, current_index):
        """Return the epoch_step-long window ending at current_index.

        Returns:
            tuple: (fs, [], data) — the empty list mimics the online
            data-source interface (presumably an event list; confirm
            against the real acquisition API).
        """
        ind = int(self.epoch_step * self.fs)
        data = self.X[:, current_index - ind:current_index].copy()
        return self.fs, [], data

    def loop(self, step_size=0.1):
        """Yield (time_in_seconds, batch) pairs, advancing by step_size.

        Starts at one full second of data so the first window is complete.
        """
        step = int(step_size * self.fs)
        for i in range(self.fs, self.X.shape[1] + 1, step):
            yield i / self.fs, self.get_data_batch(i)


def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_trial_length):
    """Run the simulated online loop and score predictions against events.

    Args:
        raw (mne.io.Raw): validation recording.
        events (ndarray): true events, mne layout (onset, 0, label).
        model_hmm: online model exposing viterbi(), model.classes_,
            and a probability attribute.
        epoch_length (float): batch window length (s).
        step_length (float): loop step (s).
        event_trial_length (float): duration of a true trial (s), used to
            rasterize the ground-truth stim channel.

    Returns:
        tuple: (figure, (p, r, f1, accuracy) with HMM,
                (p, r, f1, accuracy) without HMM).
    """
    val_data = raw.get_data()
    fs = raw.info['sfreq']
    data_gen = DataGenerator(fs, val_data, epoch_step=epoch_length)
    decision_with_hmm = []
    decision_without_hmm = []
    probs = []
    for time, data in data_gen.loop(step_length):
        step_p, cls = model_hmm.viterbi(data, return_step_p=True)
        if cls >= 0:
            # map raw HMM state index to the unified class label
            cls = model_hmm.model.classes_[cls]
        decision_with_hmm.append((time, cls))
        # naive decision: argmax of the per-step probabilities, no smoothing
        decision_without_hmm.append((time, model_hmm.model.classes_[np.argmax(step_p)]))
        probs.append((time, model_hmm.probability))
    probs = np.array(probs)

    events_pred = _construct_model_event(decision_with_hmm, fs)
    events_pred_naive = _construct_model_event(decision_without_hmm, fs)

    p_hmm, r_hmm, f1_hmm = bci_utils.event_metric(event_true=events, event_pred=events_pred, fs=fs)
    p_n, r_n, f1_n = bci_utils.event_metric(events, events_pred_naive, fs=fs)

    stim_true = _event_to_stim_channel(events, len(raw.times), trial_length=int(event_trial_length * fs))
    stim_pred = _event_to_stim_channel(events_pred, len(raw.times))
    stim_pred_naive = _event_to_stim_channel(events_pred_naive, len(raw.times))

    # sample-wise accuracy of the rasterized decisions
    accu_hmm = accuracy_score(stim_true, stim_pred)
    accu_naive = accuracy_score(stim_true, stim_pred_naive)

    fig_pred, ax = plt.subplots(4, 1, sharex=True, sharey=False)
    ax[0].set_title('pred (naive)')
    ax[0].plot(raw.times, stim_pred_naive)
    ax[1].set_title('pred')
    ax[1].plot(raw.times, stim_pred)
    ax[2].set_title('true')
    ax[2].plot(raw.times, stim_true)
    ax[3].set_title('prob')
    ax[3].plot(probs[:, 0], probs[:, 1])
    ax[3].set_ylim([0, 1])

    return fig_pred, (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)


def simulation(raw_val, event_id, model, epoch_length=1., step_length=0.1, event_trial_length=5.):
    """Model validation entry point: replay the recording through the
    online model and compare HMM-smoothed vs. naive decisions.

    Args:
        raw_val (mne.io.Raw): validation recording.
        event_id (dict): annotation-name -> event-code mapping.
        model: existing online model to validate.
        epoch_length (float): batch data length, default 1 (s).
        step_length (float): data step length, default 0.1 (s).
        event_trial_length (float): true trial duration (s).

    Returns:
        tuple: (metric_hmm, metric_naive, fig_pred) where each metric is
        (precision, recall, f1, accuracy).
    """
    fs = raw_val.info['sfreq']
    events_val, _ = mne.events_from_annotations(raw_val, event_id)
    # run with and without hmm
    fig_pred, metric_hmm, metric_naive = _evaluation_loop(
        raw_val, events_val, model,
        epoch_length, step_length,
        event_trial_length=event_trial_length)
    return metric_hmm, metric_naive, fig_pred


def _construct_model_event(decision_seq, fs):
    """Convert (time, class) decisions into an mne-style event array.

    Decisions with a negative class (no event) are dropped.

    Args:
        decision_seq (list[tuple]): (time_in_seconds, class_label) pairs.
        fs (float): sampling frequency used to convert time to samples.

    Returns:
        ndarray: shape (n_events, 3) — (sample_onset, 0, label). The
        reshape keeps rank 2 even when no decision passed the filter,
        so downstream 2-D indexing does not break on an empty result.
    """
    events = []
    for time, cls in decision_seq:
        if cls >= 0:
            events.append([int(time * fs), 0, cls])
    return np.array(events).reshape(-1, 3)


def _event_to_stim_channel(events, time_length, trial_length=None):
    """Rasterize an event array into a per-sample label channel.

    Args:
        events (ndarray): (n_events, 3) event array.
        time_length (int): number of samples in the output channel.
        trial_length (int | None): fixed trial duration in samples; when
            None, each event extends to just before the next event onset
            (the final event is then unbounded and left unpainted, as
            there is no next onset to close it).

    Returns:
        ndarray: label value per sample, zeros where no event is active.
    """
    x = np.zeros(time_length)
    for i in range(len(events)):
        if trial_length is not None:
            # fixed-length trials: paint every event, including the last
            # (the previous implementation dropped the final event)
            x[events[i, 0]: events[i, 0] + trial_length] = events[i, 2]
        elif i + 1 < len(events):
            x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
    return x


if __name__ == '__main__':
    args = parse_args()

    subj_name = args.subj
    data_dir = f'./data/{subj_name}/'
    model_path = f'./static/models/{subj_name}/{args.model_filename}'

    with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']

    # preprocess raw
    trial_time = 5.
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   ori_epoch_length=trial_time,
                                   upsampled_epoch_length=None)

    # load model
    input_kwargs = {
        'state_trans_prob': args.state_trans_prob,
        'state_change_threshold': args.state_change_threshold
    }
    model_hmm = online.model_loader(model_path, **input_kwargs)

    # do validations
    metric_hmm, metric_naive, fig_pred = simulation(
        raw, event_id, model=model_hmm,
        epoch_length=config_info['buffer_length'],
        step_length=config_info['buffer_length'],
        event_trial_length=trial_time)
    fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))

    logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
    logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')