123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
'''
Simulated online test script for a trained model.
Online-mode evaluation: event F1-score and decision trace.
'''
- import numpy as np
- import matplotlib.pyplot as plt
- import mne
- import yaml
- import os
- import argparse
- import logging
- from sklearn.metrics import accuracy_score
- from dataloaders import neo
- import bci_core.online as online
- import bci_core.utils as bci_utils
- from settings.config import settings
# Verbose (DEBUG) logging for this offline validation script.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(name=__name__)
# Project runtime configuration; the __main__ section reads 'buffer_length' from it.
config_info = settings.CONFIG_INFO
- def parse_args():
- parser = argparse.ArgumentParser(
- description='Model validation'
- )
- parser.add_argument(
- '--subj',
- dest='subj',
- help='Subject name',
- default=None,
- type=str
- )
- parser.add_argument(
- '--state-change-threshold',
- '-scth',
- dest='state_change_threshold',
- help='Threshold for HMM state change',
- default=0.75,
- type=float
- )
- parser.add_argument(
- '--state-trans-prob',
- '-stp',
- dest='state_trans_prob',
- help='Transition probability for HMM state change',
- default=0.8,
- type=float
- )
- parser.add_argument(
- '--model-filename',
- dest='model_filename',
- help='Model filename',
- default=None,
- type=str
- )
- return parser.parse_args()
class DataGenerator:
    """Replay a recorded signal as a stream of fixed-length sliding windows.

    Simulates an online acquisition buffer: each step hands out the most recent
    ``epoch_step`` seconds of ``X``.

    Args:
        fs: sampling rate in Hz (converted with ``int``).
        X: 2-D array sliced along axis 1 (samples); the caller in this file
            passes ``raw.get_data()``, i.e. ``(n_channels, n_times)``.
        epoch_step (float): window length in seconds, default 1.
    """

    def __init__(self, fs, X, epoch_step=1.):
        self.fs = int(fs)
        self.X = X
        self.epoch_step = epoch_step

    def _window_samples(self):
        """Window length in samples.

        Rounded rather than truncated, so that e.g. 0.57 s at 100 Hz gives
        57 samples instead of 56 caused by float round-off.
        """
        return int(round(self.epoch_step * self.fs))

    def get_data_batch(self, current_index):
        """Return ``(fs, [], window)`` for the window ending at ``current_index``.

        ``window`` is a copy of ``X[:, current_index - n:current_index]``, where
        ``n`` is the window length in samples. The empty list is passed through
        unchanged to the model's ``viterbi`` (its meaning is defined there --
        TODO confirm).
        """
        n_window = self._window_samples()
        data = self.X[:, current_index - n_window:current_index].copy()
        return self.fs, [], data

    def loop(self, step_size=0.1):
        """Yield ``(time_s, batch)`` for every complete window.

        Args:
            step_size (float): hop between consecutive windows in seconds;
                must be positive. Values shorter than one sample are rounded
                up to one sample.

        Yields:
            tuple: ``(end_index / fs, get_data_batch(end_index))``.

        Raises:
            ValueError: if ``step_size`` is not positive.
        """
        if step_size <= 0:
            raise ValueError(f'step_size must be positive, got {step_size}')
        step = max(1, int(round(step_size * self.fs)))
        # The first window ends exactly one window length in. This guarantees
        # every yielded slice is complete for any epoch_step (a fixed start at
        # fs gave negative, wrapped slice starts whenever epoch_step > 1 s).
        for end in range(self._window_samples(), self.X.shape[1] + 1, step):
            yield end / self.fs, self.get_data_batch(end)
def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_trial_length):
    """Replay ``raw`` through ``model_hmm`` and score decisions with and without the HMM.

    Args:
        raw: MNE Raw object; ``get_data()``, ``info['sfreq']`` and ``times`` are read.
        events: ground-truth event array (columns 0 = sample and 2 = label are used).
        model_hmm: online model exposing ``viterbi(batch, return_step_p=True)``,
            ``model.classes_`` and ``probability``.
        epoch_length (float): window length in seconds.
        step_length (float): hop between windows in seconds.
        event_trial_length (float): seconds painted per true event in the
            ground-truth trace.

    Returns:
        tuple: ``(figure, (precision, recall, f1, accuracy) with HMM,
        (precision, recall, f1, accuracy) using the per-step argmax)``.
    """
    sfreq = raw.info['sfreq']
    stream = DataGenerator(sfreq, raw.get_data(), epoch_step=epoch_length)

    hmm_decisions, naive_decisions, prob_trace = [], [], []
    for t, batch in stream.loop(step_length):
        step_p, state = model_hmm.viterbi(batch, return_step_p=True)
        label_map = model_hmm.model.classes_
        # A negative state means "no decision" and is kept as-is; otherwise map
        # the state index to the model's unified label.
        hmm_decisions.append((t, label_map[state] if state >= 0 else state))
        naive_decisions.append((t, label_map[np.argmax(step_p)]))
        prob_trace.append((t, model_hmm.probability))
    prob_trace = np.array(prob_trace)

    pred_events = _construct_model_event(hmm_decisions, sfreq)
    naive_events = _construct_model_event(naive_decisions, sfreq)
    prec_h, rec_h, f1_h = bci_utils.event_metric(event_true=events, event_pred=pred_events, fs=sfreq)
    prec_n, rec_n, f1_n = bci_utils.event_metric(events, naive_events, fs=sfreq)

    n_samples = len(raw.times)
    true_stim = _event_to_stim_channel(events, n_samples, trial_length=int(event_trial_length * sfreq))
    hmm_stim = _event_to_stim_channel(pred_events, n_samples)
    naive_stim = _event_to_stim_channel(naive_events, n_samples)
    acc_hmm = accuracy_score(true_stim, hmm_stim)
    acc_naive = accuracy_score(true_stim, naive_stim)

    fig, axes = plt.subplots(4, 1, sharex=True, sharey=False)
    panels = (('pred (naive)', naive_stim), ('pred', hmm_stim), ('true', true_stim))
    for axis, (title, trace) in zip(axes, panels):
        axis.set_title(title)
        axis.plot(raw.times, trace)
    axes[3].set_title('prob')
    axes[3].plot(prob_trace[:, 0], prob_trace[:, 1])
    axes[3].set_ylim([0, 1])
    return fig, (prec_h, rec_h, f1_h, acc_hmm), (prec_n, rec_n, f1_n, acc_naive)
def simulation(raw_val, event_id, model,
               epoch_length=1.,
               step_length=0.1,
               event_trial_length=5.):
    """Validate a loaded online model on recorded data in a simulated online run.

    Ground-truth events are extracted from the annotations of ``raw_val``. The
    data is then replayed window by window through ``model`` and scored twice:
    once with the model's HMM decisions and once with the plain per-step argmax.

    Args:
        raw_val (mne.io.Raw): validation recording with annotations.
        event_id (dict): mapping passed to ``mne.events_from_annotations``.
        model: loaded online model to validate.
        epoch_length (float): batch (window) length in seconds, default 1.
        step_length (float): step between windows in seconds, default 0.1.
        event_trial_length (float): duration in seconds of each ground-truth
            trial in the reference trace, default 5.

    Returns:
        tuple: ``(metric_hmm, metric_naive, fig_pred)``. Each metric is
        ``(precision, recall, f1_score, accuracy)``; ``fig_pred`` is the
        matplotlib figure comparing predicted and true traces.
    """
    events_val, _ = mne.events_from_annotations(raw_val, event_id)
    # Evaluate both with and without the HMM smoothing.
    fig_pred, metric_hmm, metric_naive = _evaluation_loop(
        raw_val,
        events_val,
        model,
        epoch_length,
        step_length,
        event_trial_length=event_trial_length,
    )
    return metric_hmm, metric_naive, fig_pred
- def _construct_model_event(decision_seq, fs):
- events = []
- for i in decision_seq:
- time, cls = i
- if cls >= 0:
- events.append([int(time * fs), 0, cls])
- return np.array(events)
- def _event_to_stim_channel(events, time_length, trial_length=None):
- x = np.zeros(time_length)
- for i in range(0, len(events) - 1):
- if trial_length is not None:
- x[events[i, 0]: events[i, 0] + trial_length] = events[i, 2]
- else:
- x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
- return x
if __name__ == '__main__':
    args = parse_args()
    # Both paths below are built from these arguments. With their default
    # (None) they would silently point at './data/None/' and '.../None'.
    if args.subj is None or args.model_filename is None:
        raise SystemExit('Both --subj and --model-filename are required.')
    subj_name = args.subj
    data_dir = f'./data/{subj_name}/'

    model_path = f'./static/models/{subj_name}/{args.model_filename}'
    with open(os.path.join(data_dir, 'val_info.yml'), 'r', encoding='utf-8') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']

    # Trial duration in seconds: used when loading the data and as the
    # length of each ground-truth trial in the evaluation.
    trial_time = 5.
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   ori_epoch_length=trial_time,
                                   upsampled_epoch_length=None)

    # Load the model with the HMM settings from the command line.
    input_kwargs = {
        'state_trans_prob': args.state_trans_prob,
        'state_change_threshold': args.state_change_threshold,
    }
    model_hmm = online.model_loader(model_path, **input_kwargs)

    # Run the validation.
    # NOTE(review): step_length equals epoch_length, i.e. non-overlapping
    # windows (simulation's own default step is 0.1 s) -- confirm intended.
    metric_hmm, metric_naive, fig_pred = simulation(raw,
                                                    event_id,
                                                    model=model_hmm,
                                                    epoch_length=config_info['buffer_length'],
                                                    step_length=config_info['buffer_length'],
                                                    event_trial_length=trial_time)
    fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
    plt.close(fig_pred)
    logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
    logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')
|