123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
'''
Offline simulation of the online decoding pipeline.

Online-mode test: replays a validation recording through the model and reports
event-level precision / recall / F1-score, sample-level accuracy, and a plot of
the decision trace.
'''
import numpy as np
import matplotlib.pyplot as plt
import mne
import yaml
import os
import argparse
import logging
from sklearn.metrics import accuracy_score
from dataloaders import neo
import bci_core.online as online
import bci_core.utils as bci_utils
from settings.config import settings
# NOTE(review): this configures the root logger at import time, so any module
# that imports this file (e.g. to reuse `simulation`) is switched to DEBUG
# logging as well; consider moving the call under the `__main__` guard.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Project-wide settings; the `__main__` block reads `buffer_length` from it and
# uses it as the replay window length / hop in seconds.
config_info = settings.CONFIG_INFO
def parse_args():
    """Parse the command-line options of the online-simulation script.

    Returns:
        argparse.Namespace: with attributes ``subj``, ``state_change_threshold``,
        ``state_trans_prob`` and ``model_filename``.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    # --subj and --model-filename are interpolated into file paths by the
    # script, so they are required: a missing value would otherwise silently
    # become the literal string 'None' inside the path.
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    parser.add_argument(
        '--state-change-threshold',
        '-scth',
        dest='state_change_threshold',
        help='Threshold for HMM state change',
        default=0.75,
        type=float
    )
    parser.add_argument(
        '--state-trans-prob',
        '-stp',
        dest='state_trans_prob',
        help='Transition probability for HMM state change',
        default=0.8,
        type=float
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        required=True,
        type=str
    )
    return parser.parse_args()
class DataGenerator:
    """Replay a pre-recorded signal as a stream of fixed-length windows.

    Each yielded window ends at the current replay position and spans
    ``epoch_step`` seconds, mimicking the buffer of the online system.

    Args:
        fs (float): sampling rate in Hz; truncated to ``int``.
        X (np.ndarray): 2-D signal indexed as ``X[channel, sample]``.
        epoch_step (float): window length in seconds, default 1.
    """

    def __init__(self, fs, X, epoch_step=1.):
        self.fs = int(fs)
        self.X = X
        self.epoch_step = epoch_step

    def get_data_batch(self, current_index):
        """Return the ``epoch_step``-second window ending at ``current_index``.

        Args:
            current_index (int): exclusive end sample of the window.

        Returns:
            tuple: ``(fs, [], data)`` where ``data`` is a copy of
            ``X[:, current_index - window:current_index]``. The empty list is
            passed through as-is (presumably an event placeholder expected by
            the online model's input format -- TODO confirm).
        """
        window = int(self.epoch_step * self.fs)
        data = self.X[:, current_index - window:current_index].copy()
        return self.fs, [], data

    def loop(self, step_size=0.1):
        """Slide over the recording and yield ``(end_time_s, batch)`` pairs.

        The first window ends at ``max(1 s, epoch_step)``. The original 1 s
        warm-up is kept, but replay never starts before a full window exists:
        a negative slice start would be interpreted by NumPy as an offset from
        the END of the array and yield an empty or wrong window.

        Args:
            step_size (float): hop between consecutive windows, in seconds.

        Yields:
            tuple: ``(end_time_in_seconds, (fs, [], data))``.

        Raises:
            ValueError: if ``step_size`` is shorter than one sample.
        """
        step = int(step_size * self.fs)
        if step <= 0:
            raise ValueError(
                f'step_size={step_size} s is shorter than one sample at fs={self.fs} Hz')
        window = int(self.epoch_step * self.fs)
        start = max(self.fs, window)
        for i in range(start, self.X.shape[1] + 1, step):
            yield i / self.fs, self.get_data_batch(i)
def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
    """Replay ``raw`` through the decoder and score it with and without HMM smoothing.

    The recording is cut into windows of ``step_length`` seconds, advanced by
    ``step_length`` seconds, and each window is passed to
    ``model_hmm.viterbi``. Two decision traces are collected: the HMM output
    class returned by ``viterbi`` and the naive arg-max of its per-step
    probabilities.

    Args:
        raw (mne.io.Raw): validation recording.
        events (np.ndarray): ground-truth events; ``events[i, 0]`` is the
            sample index and ``events[i, 2]`` the label.
        model_hmm: decoder exposing ``viterbi(data, return_step_p=True)``,
            ``model.classes_`` and ``probability``.
        step_length (float): window length and hop, in seconds.
        event_trial_length (float): duration of each ground-truth trial, in seconds.

    Returns:
        tuple: ``(fig, metric_hmm, metric_naive)`` where each metric is
        ``(precision, recall, f1_score, accuracy)``.
    """
    fs = raw.info['sfreq']
    n_samples = len(raw.times)
    data_gen = DataGenerator(fs, raw.get_data(), epoch_step=step_length)

    decision_with_hmm = []
    decision_without_hmm = []
    probs = []
    for time, data in data_gen.loop(step_length):
        step_p, cls = model_hmm.viterbi(data, return_step_p=True)
        # Negative class indices are kept unmapped and later skipped by
        # _construct_model_event (presumably "no state change" -- confirm
        # against online.Controller); others map to the unified label.
        if cls >= 0:
            cls = model_hmm.model.classes_[cls]
        decision_with_hmm.append((time, cls))
        decision_without_hmm.append((time, model_hmm.model.classes_[np.argmax(step_p)]))
        probs.append((time, model_hmm.probability))
    # reshape keeps an (n, 2) layout even when the recording is too short for a
    # single window, so the column indexing in the plot does not fail.
    probs = np.array(probs).reshape(-1, 2)

    events_pred = _construct_model_event(decision_with_hmm, fs)
    events_pred_naive = _construct_model_event(decision_without_hmm, fs)
    # keyword arguments for both calls so true/pred can never be swapped
    p_hmm, r_hmm, f1_hmm = bci_utils.event_metric(event_true=events, event_pred=events_pred, fs=fs)
    p_n, r_n, f1_n = bci_utils.event_metric(event_true=events, event_pred=events_pred_naive, fs=fs)

    stim_true = _event_to_stim_channel(events, n_samples, trial_length=int(event_trial_length * fs))
    stim_pred = _event_to_stim_channel(events_pred, n_samples)
    stim_pred_naive = _event_to_stim_channel(events_pred_naive, n_samples)
    accu_hmm = accuracy_score(stim_true, stim_pred)
    accu_naive = accuracy_score(stim_true, stim_pred_naive)

    fig_pred = _plot_decision_trace(raw.times, stim_pred_naive, stim_pred, stim_true, probs)
    return fig_pred, (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)


def _plot_decision_trace(times, stim_pred_naive, stim_pred, stim_true, probs):
    """Plot naive/HMM predictions, ground truth and decoder probability on a shared time axis.

    Args:
        times (array-like): time stamps of the per-sample traces, in seconds.
        stim_pred_naive, stim_pred, stim_true (np.ndarray): per-sample label traces.
        probs (np.ndarray): ``(n, 2)`` array of ``(time, probability)`` rows.

    Returns:
        matplotlib.figure.Figure: the 4-row figure.
    """
    fig, ax = plt.subplots(4, 1, sharex=True, sharey=False)
    ax[0].set_title('pred (naive)')
    ax[0].plot(times, stim_pred_naive)
    ax[1].set_title('pred')
    ax[1].plot(times, stim_pred)
    ax[2].set_title('true')
    ax[2].plot(times, stim_true)
    ax[3].set_title('prob')
    ax[3].plot(probs[:, 0], probs[:, 1])
    ax[3].set_ylim([0, 1])
    return fig
def simulation(raw_val, event_id, model,
               state_trans_prob=0.8,
               state_change_threshold=0.8,
               step_length=1.,
               event_trial_length=5.):
    """Simulate online decoding on a recorded session and score the decisions.

    Ground-truth events are read from the annotations of ``raw_val`` and
    rebuilt with ``neo.reconstruct_events``. The model is wrapped in an
    ``online.Controller``, and the recording is replayed window by window with
    and without HMM smoothing.

    Args:
        raw_val (mne.io.Raw): validation recording with annotations.
        event_id (dict): annotation name -> event code. The key ``'rest'``
            marks rest trials; every other key is treated as a movement trial.
        model: model handed to ``online.Controller`` (the script passes a file path).
        state_trans_prob (float): HMM state transition probability, default 0.8.
        state_change_threshold (float): HMM state change threshold, default 0.8.
        step_length (float): window length and hop of the replayed data, in
            seconds, default 1.
        event_trial_length (float): duration of each ground-truth trial, in
            seconds, default 5.

    Returns:
        tuple: ``(metric_hmm, metric_naive, fig_pred)``. Each metric is
        ``(precision, recall, f1_score, accuracy)`` and ``fig_pred`` is the
        matplotlib figure of the decision trace.
    """
    fs = raw_val.info['sfreq']
    events_val, _ = mne.events_from_annotations(raw_val, event_id)

    rest_codes = [code for name, code in event_id.items() if name == 'rest']
    mov_codes = [code for name, code in event_id.items() if name != 'rest']
    events_val = neo.reconstruct_events(events_val,
                                        fs,
                                        finger_model=None,
                                        rest_trial_ind=rest_codes,
                                        mov_trial_ind=mov_codes,
                                        use_original_label=True)

    controller = online.Controller(0, model,
                                   state_trans_prob=state_trans_prob,
                                   state_change_threshold=state_change_threshold)
    model_hmm = controller.real_feedback_model
    # run with and without HMM smoothing
    fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val,
                                                          events_val,
                                                          model_hmm,
                                                          step_length,
                                                          event_trial_length=event_trial_length)
    return metric_hmm, metric_naive, fig_pred
- def _construct_model_event(decision_seq, fs):
- events = []
- for i in decision_seq:
- time, cls = i
- if cls >= 0:
- events.append([int(time * fs), 0, cls])
- return np.array(events)
- def _event_to_stim_channel(events, time_length, trial_length=None):
- x = np.zeros(time_length)
- for i in range(0, len(events) - 1):
- if trial_length is not None:
- x[events[i, 0]: events[i, 0] + trial_length] = events[i, 2]
- else:
- x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
- return x
if __name__ == '__main__':
    args = parse_args()
    if args.subj is None or args.model_filename is None:
        # Both values are interpolated into paths below; without them the paths
        # would silently contain the literal string 'None'.
        raise SystemExit('error: --subj and --model-filename are required')
    subj_name = args.subj
    # Trailing slash kept as in the original: data_dir is passed verbatim to
    # neo.raw_preprocessing, which may append file names directly -- confirm.
    data_dir = f'./data/{subj_name}/'

    model_path = f'./static/models/{subj_name}/{args.model_filename}'
    with open(os.path.join(data_dir, 'val_info.yml'), 'r', encoding='utf-8') as info_file:
        info = yaml.safe_load(info_file)
    sessions = info['sessions']
    # Every session key is looked up in neo.FINGERMODEL_IDS to get its event code.
    event_id = {'rest': 0}
    for finger in sessions:
        event_id[finger] = neo.FINGERMODEL_IDS[finger]

    # preprocess raw
    trial_time = 5.
    # NOTE(review): 2 / 1 are presumably the original annotation codes of
    # movement / rest trials in the recordings -- confirm against neo.
    raw = neo.raw_preprocessing(data_dir, sessions,
                                unify_label=True,
                                ori_epoch_length=trial_time,
                                mov_trial_ind=[2],
                                rest_trial_ind=[1],
                                upsampled_epoch_length=None)
    # do validations
    metric_hmm, metric_naive, fig_pred = simulation(raw,
                                                    event_id,
                                                    model=model_path,
                                                    state_trans_prob=args.state_trans_prob,
                                                    state_change_threshold=args.state_change_threshold,
                                                    step_length=config_info['buffer_length'],
                                                    event_trial_length=trial_time)
    fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
    logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
    logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')
|