'''
Simulated online test script for a trained model.

Online-mode test: replays a validation recording through the online decoder
step by step and reports the event F1-score and the decision trace, both with
and without HMM smoothing.
'''
import numpy as np
import matplotlib.pyplot as plt
import mne
import yaml
import os
import argparse
import logging
from sklearn.metrics import accuracy_score

from dataloaders import neo
import bci_core.online as online
import bci_core.utils as bci_utils
import bci_core.viz as bci_viz
from settings.config import settings


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO

# Tolerance added before flooring a quantity expressed in decision-step units.
# Values that should be exact multiples of the step can land just below the
# integer in binary floating point (e.g. 0.3 / 0.1 == 2.9999999999999996);
# plain int() truncation would then place them one step too early.
_STEP_EPS = 1e-6


def _floor_steps(value):
    """Floor a non-negative step-unit quantity, tolerating float round-off.

    Works element-wise on numpy arrays. Returns a float (or float array), so
    callers cast to the integer type they need.
    """
    return np.floor(value + _STEP_EPS)


def parse_args():
    """Parse the command-line options of this script.

    ``--subj`` and ``--model-filename`` are required: both are joined into
    filesystem paths, so running without them could never succeed.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    parser.add_argument(
        '--state-change-threshold',
        '-scth',
        dest='state_change_threshold',
        help='Threshold for HMM state change',
        default=0.75,
        type=float
    )
    parser.add_argument(
        '--state-trans-prob',
        '-stp',
        dest='state_trans_prob',
        help='Transition probability for HMM state change',
        default=0.8,
        type=float
    )
    parser.add_argument(
        '--momentum',
        help='Probability update momentum',
        default=0.5,
        type=float
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        required=True,
        type=str
    )
    return parser.parse_args()


class DataGenerator:
    """Replay a continuous recording as a stream of fixed-length windows.

    Args:
        fs: sampling rate in Hz (stored truncated to ``int``).
        X: 2-D array with samples along axis 1.
        epoch_step (float): window length in seconds.
    """

    def __init__(self, fs, X, epoch_step=1.):
        self.fs = int(fs)
        self.X = X
        self.epoch_step = epoch_step

    def _window_samples(self):
        # Window length in samples.
        return int(self.epoch_step * self.fs)

    def _step_samples(self, step_size):
        # Hop between consecutive windows in samples; a zero hop would make
        # range() fail with an opaque message, so fail clearly instead.
        step = int(step_size * self.fs)
        if step < 1:
            raise ValueError(
                f'step_size={step_size} s is shorter than one sample at fs={self.fs} Hz'
            )
        return step

    def _window_ends(self, step_size):
        # Exclusive end index of every window; the first window ends exactly
        # one window length into the recording, the last at or before its end.
        return range(self._window_samples(), self.X.shape[1] + 1,
                     self._step_samples(step_size))

    def get_data_batch(self, current_index):
        """Return the window of data ending (exclusively) at ``current_index``.

        Returns:
            tuple: ``(fs, [], data)`` where ``data`` is a copy of the window.
            The empty list presumably stands in for the event list an online
            data source would deliver — TODO confirm against the online client.
        """
        n_window = self._window_samples()
        data = self.X[:, current_index - n_window:current_index].copy()
        return self.fs, [], data

    def loop(self, step_size=0.1):
        """Yield ``(time_s, batch)`` for each window, advancing by ``step_size`` s.

        ``time_s`` is the end time of the window in seconds; ``batch`` is the
        tuple returned by :meth:`get_data_batch`.
        """
        for end in self._window_ends(step_size):
            yield end / self.fs, self.get_data_batch(end)

    @property
    def time_range(self):
        """``(first window end, recording duration)`` in seconds."""
        return self.epoch_step, self.X.shape[1] / self.fs

    def time_steps(self, step_size=0.1):
        """Number of windows :meth:`loop` yields for ``step_size``."""
        return len(self._window_ends(step_size))


def _plot_decision_trace(time_range, stim_true, stim_pred, probs, n_classes, title, sharey=False):
    """Plot true states, decoded states and per-class probabilities in one figure.

    The figure has ``n_classes + 2`` stacked axes: the true state sequence,
    the decoded state sequence, then one probability trace per class (column
    ``i`` of ``probs``) drawn against the true states as cue.

    Returns:
        matplotlib.figure.Figure
    """
    fig, axes = plt.subplots(n_classes + 2, 1, sharex=True, sharey=sharey, figsize=(10, 8))
    axes[0].set_title('True states')
    bci_viz.plot_states(time_range, stim_true, ax=axes[0])
    axes[1].set_title('State sequence')
    bci_viz.plot_states(time_range, stim_pred, ax=axes[1])
    for i, ax in enumerate(axes[2:]):
        bci_viz.plot_state_prob_with_cue(time_range, stim_true, probs[:, i], ax=ax)
    fig.suptitle(title)
    return fig


def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_trial_length):
    """Replay ``raw`` through ``model_hmm`` and score the decisions against ``events``.

    Args:
        raw (mne.io.Raw): validation recording.
        events (np.ndarray): events array with onsets (in samples) in column 0
            and labels in column 2. The caller's array is not modified.
        model_hmm: online model exposing ``viterbi``, ``probability``,
            ``n_classes`` and ``model.classes_``.
        epoch_length (float): decoding window length (s).
        step_length (float): time between consecutive decisions (s).
        event_trial_length (float): trial duration (s); converted to steps
            and passed as ``trial_length`` when building the true state sequence.

    Returns:
        tuple: ``((fig_hmm, fig_naive), (p, r, f1, acc)_hmm, (p, r, f1, acc)_naive)``.
    """
    val_data = raw.get_data()
    fs = raw.info['sfreq']
    data_gen = DataGenerator(fs, val_data, epoch_step=epoch_length)
    step_rate = 1 / step_length

    # Re-express event onsets in decision-step units (rate 1 / step_length),
    # on a copy so the caller's array is left untouched.
    events = events.copy()
    events[:, 0] = _floor_steps(events[:, 0] / fs / step_length).astype(np.int32)

    decision_with_hmm = []
    decision_without_hmm = []
    probs = []
    probs_naive = []
    for t, (batch_fs, _, data) in data_gen.loop(step_length):
        step_p, cls = model_hmm.viterbi(batch_fs, data, return_step_p=True)
        # Non-negative results are class indices; negative values pass through
        # unchanged (_construct_model_event treats -1 as "keep previous state").
        if cls >= 0:
            cls = model_hmm.model.classes_[cls]
        decision_with_hmm.append((t, cls))
        # map to unified label
        decision_without_hmm.append((t, model_hmm.model.classes_[np.argmax(step_p)]))
        # Copy in case the model updates its probability buffer in place.
        probs.append(np.array(model_hmm.probability))
        probs_naive.append(step_p)

    probs = np.array(probs)
    probs_naive = np.array(probs_naive)

    events_pred = _construct_model_event(decision_with_hmm, step_rate)
    events_pred_naive = _construct_model_event(decision_without_hmm, step_rate)

    # NOTE(review): all events here are in decision-step units (rate
    # 1 / step_length), yet the raw sampling rate is passed as fs — confirm
    # which rate bci_utils.event_metric expects.
    metric_fs = data_gen.fs
    p_hmm, r_hmm, f1_hmm = bci_utils.event_metric(event_true=events, event_pred=events_pred, fs=metric_fs)
    p_n, r_n, f1_n = bci_utils.event_metric(events, events_pred_naive, fs=metric_fs)

    time_steps = data_gen.time_steps(step_length)
    # Step index of the first decision (the first window ends epoch_length s in).
    start_ind = int(_floor_steps(data_gen.time_range[0] / step_length))
    trial_steps = int(_floor_steps(event_trial_length / step_length))
    stim_true = bci_utils.event_to_stim_channel(events, time_steps,
                                                trial_length=trial_steps,
                                                start_ind=start_ind)
    stim_pred = bci_utils.event_to_stim_channel(events_pred, time_steps, start_ind=start_ind)
    stim_pred_naive = bci_utils.event_to_stim_channel(events_pred_naive, time_steps, start_ind=start_ind)
    accu_hmm = accuracy_score(stim_true, stim_pred)
    accu_naive = accuracy_score(stim_true, stim_pred_naive)

    fig_hmm = _plot_decision_trace(data_gen.time_range, stim_true, stim_pred, probs,
                                   model_hmm.n_classes, title='With HMM', sharey=False)
    fig_naive = _plot_decision_trace(data_gen.time_range, stim_true, stim_pred_naive, probs_naive,
                                     model_hmm.n_classes, title='Naive', sharey=True)

    return (fig_hmm, fig_naive), (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)


def simulation(raw_val, event_id, model, epoch_length=1., step_length=0.1, event_trial_length=5.):
    """Validate a model by replaying a recording as if it were streamed online.

    Args:
        raw_val (mne.io.Raw): validation recording carrying annotations.
        event_id (dict): passed to ``mne.events_from_annotations`` to select
            and code the annotations used as ground truth.
        model: loaded online model to validate.
        epoch_length (float): decoding window length, default 1 (s).
        step_length (float): time between consecutive decisions, default 0.1 (s).
        event_trial_length (float): trial duration, default 5 (s); used to
            build the true state sequence.

    Returns:
        tuple: ``(metric_hmm, metric_naive, figs)`` where each metric is
        ``(precision, recall, f1_score, accuracy)`` and ``figs`` is
        ``(fig_hmm, fig_naive)``.
    """
    events_val, _ = mne.events_from_annotations(raw_val, event_id)

    # run with and without hmm
    fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, events_val, model,
                                                          epoch_length, step_length,
                                                          event_trial_length=event_trial_length)
    return metric_hmm, metric_naive, fig_pred


def _construct_model_event(decision_seq, fs, start_cond=0):
    """Convert a per-step decision sequence into change-point events.

    Args:
        decision_seq: list of ``(time_s, label)`` pairs. A label of -1 means
            "no decision" and carries the previous label forward.
        fs (float): rate used to turn times into event indices
            (``1 / step_length`` in this script).
        start_cond: label forced onto the first step, whatever its decision.

    Returns:
        np.ndarray: rows ``[index, 0, label]`` — one for the first step and
        one for every subsequent label change. Empty with shape ``(0, 3)``
        when ``decision_seq`` is empty.
    """
    if len(decision_seq) == 0:
        return np.empty((0, 3), dtype=int)

    # Forward-fill: the first step gets start_cond; -1 keeps the previous label.
    filled = [(decision_seq[0][0], start_cond)]
    for t, label in decision_seq[1:]:
        filled.append((t, filled[-1][1] if label == -1 else label))

    last_state = filled[0][1]
    events = [[int(_floor_steps(filled[0][0] * fs)), 0, last_state]]
    for t, label in filled[1:]:
        if label != last_state:
            last_state = label
            events.append([int(_floor_steps(t * fs)), 0, label])
    return np.array(events)


if __name__ == '__main__':
    args = parse_args()
    subj_name = args.subj

    data_dir = os.path.join(settings.DATA_PATH, subj_name)
    model_path = os.path.join(settings.MODEL_PATH, subj_name, args.model_filename)

    with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']

    # preprocess raw
    # Trial duration (s): used for loading and as event_trial_length below.
    trial_time = 5.
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   reref_method=config_info['reref'],
                                   ori_epoch_length=trial_time,
                                   upsampled_epoch_length=None)

    # load model
    input_kwargs = {
        'state_trans_prob': args.state_trans_prob,
        'state_change_threshold': args.state_change_threshold,
        'momentum': args.momentum
    }
    model_hmm = online.model_loader(model_path, **input_kwargs)

    # do online simulation
    metric_hmm, metric_naive, fig_pred = simulation(raw, event_id,
                                                    model=model_hmm,
                                                    epoch_length=config_info['buffer_length'],
                                                    step_length=0.1,
                                                    event_trial_length=trial_time)

    fig_pred[0].savefig(os.path.join(data_dir, 'pred_hmm.pdf'))
    fig_pred[1].savefig(os.path.join(data_dir, 'pred_naive.pdf'))

    logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
    logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')
    plt.show()