123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
- '''
- 模型模拟在线测试脚本
- 在线模式测试:event f1-score and decision trace
- '''
- import numpy as np
- import matplotlib.pyplot as plt
- import mne
- import yaml
- import os
- import argparse
- import logging
- from sklearn.metrics import accuracy_score
- from dataloaders import neo
- import bci_core.online as online
- import bci_core.utils as bci_utils
- import bci_core.viz as bci_viz
- from settings.config import settings
# Verbose logging for simulation runs: DEBUG and above go to the root handler.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Project-wide configuration; this script reads its 'reref' and 'buffer_length' entries.
config_info = settings.CONFIG_INFO
def parse_args(argv=None):
    """Parse command-line options for the online-simulation script.

    Args:
        argv: optional list of argument strings; ``None`` (default) reads
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with ``subj``, ``state_change_threshold``,
        ``state_trans_prob``, ``momentum`` and ``model_filename``.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    # --subj and --model-filename are joined into filesystem paths by the
    # script; os.path.join(None) raises TypeError, so require them up front
    # and let argparse report a clear usage error instead.
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    parser.add_argument(
        '--state-change-threshold',
        '-scth',
        dest='state_change_threshold',
        help='Threshold for HMM state change',
        default=0.75,
        type=float
    )
    parser.add_argument(
        '--state-trans-prob',
        '-stp',
        dest='state_trans_prob',
        help='Transition probability for HMM state change',
        default=0.8,
        type=float
    )
    parser.add_argument(
        '--momentum',
        help='Probability update momentum',
        default=0.5,
        type=float
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        required=True,
        type=str
    )
    return parser.parse_args(argv)
class DataGenerator:
    """Replay a continuous recording as a stream of sliding windows.

    Simulates online acquisition: each step yields the most recent
    ``epoch_step`` seconds of data ending at the current sample.

    Args:
        fs: sampling rate in Hz; stored truncated to ``int``.
        X: 2-D array whose axis 1 is time in samples (axis 0 presumably
            channels — confirm against the data loader).
        epoch_step: window length in seconds (default 1.0).
    """

    def __init__(self, fs, X, epoch_step=1.):
        self.fs = int(fs)
        self.X = X
        self.epoch_step = epoch_step

    def _n_samples(self, seconds):
        """Convert a duration in seconds to a whole number of samples.

        Truncates like ``int()`` but first rounds away binary floating-point
        noise, so 0.57 s at 100 Hz gives 57 samples rather than 56
        (``0.57 * 100 == 56.99999999999999``).
        """
        return int(round(seconds * self.fs, 6))

    def _window_ends(self, step_size):
        """Return the range of exclusive window-end indices for ``step_size`` seconds.

        Raises:
            ValueError: if ``step_size`` is shorter than one sample.
        """
        step = self._n_samples(step_size)
        if step < 1:
            raise ValueError(
                f'step_size={step_size} s is shorter than one sample at fs={self.fs} Hz')
        return range(self._n_samples(self.epoch_step), self.X.shape[1] + 1, step)

    def get_data_batch(self, current_index):
        """Return the ``epoch_step``-second window ending just before ``current_index``.

        Returns:
            (fs, events, data): the integer sampling rate, an always-empty event
            list, and a copy of ``X[:, current_index - n:current_index]`` where
            ``n`` is the window length in samples.
        """
        n = self._n_samples(self.epoch_step)
        data = self.X[:, current_index - n:current_index].copy()
        return self.fs, [], data

    def loop(self, step_size=0.1):
        """Yield ``(time_s, (fs, events, data))`` per window, hopping ``step_size`` seconds.

        ``time_s`` is the window end index divided by ``fs``. The first window
        ends at ``epoch_step`` seconds; the last ends at or before the final sample.
        """
        for end in self._window_ends(step_size):
            yield end / self.fs, self.get_data_batch(end)

    @property
    def time_range(self):
        """``(start, end)`` in seconds: first window end time and recording duration."""
        return self.epoch_step, self.X.shape[1] / self.fs

    def time_steps(self, step_size=0.1):
        """Number of windows that ``loop(step_size)`` yields."""
        return len(self._window_ends(step_size))
def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_trial_length):
    """Replay ``raw`` through ``model_hmm`` and score decisions with and without the HMM.

    Args:
        raw: MNE Raw object to replay (``get_data()`` and ``info['sfreq']`` are read).
        events: events array whose column 0 is the onset sample index. It is
            NOT modified; a converted copy is used internally.
        model_hmm: online model exposing ``viterbi``, ``probability``,
            ``model.classes_`` and ``n_classes``.
        epoch_length: sliding-window length in seconds.
        step_length: hop between decisions in seconds; also the time resolution
            of the event streams compared by the metrics.
        event_trial_length: duration in seconds of each true event, used to
            build the ground-truth state sequence.

    Returns:
        ``((fig_hmm, fig_naive), (precision, recall, f1, accuracy) with HMM,
        (precision, recall, f1, accuracy) without HMM)``.
    """
    val_data = raw.get_data()
    sfreq = raw.info['sfreq']
    step_rate = 1 / step_length  # decisions per second
    data_gen = DataGenerator(sfreq, val_data, epoch_step=epoch_length)

    # Re-express true onsets in decision-step units. Work on a copy so the
    # caller's array is untouched, and round off float noise before truncating
    # (e.g. 0.3 / 0.1 == 2.9999999999999996 would otherwise become 2).
    events = events.copy()
    events[:, 0] = np.round(events[:, 0] / sfreq / step_length, 6).astype(np.int32)

    decision_with_hmm = []
    decision_without_hmm = []
    probs = []
    probs_naive = []
    for t, (batch_fs, _, data) in data_gen.loop(step_length):
        step_p, cls = model_hmm.viterbi(batch_fs, data, return_step_p=True)
        # Non-negative cls is a class index -> map it to the unified label.
        # Negative cls is passed through; _construct_model_event treats -1 as
        # "keep the previous state".
        if cls >= 0:
            cls = model_hmm.model.classes_[cls]
        decision_with_hmm.append((t, cls))
        decision_without_hmm.append((t, model_hmm.model.classes_[np.argmax(step_p)]))
        # Snapshot the value, in case the model updates its probability buffer
        # in place between steps.
        probs.append(np.array(model_hmm.probability, copy=True))
        probs_naive.append(step_p)
    probs = np.array(probs)
    probs_naive = np.array(probs_naive)

    events_pred = _construct_model_event(decision_with_hmm, step_rate)
    events_pred_naive = _construct_model_event(decision_without_hmm, step_rate)
    p_hmm, r_hmm, f1_hmm = bci_utils.event_metric(event_true=events, event_pred=events_pred, fs=step_rate)
    p_n, r_n, f1_n = bci_utils.event_metric(events, events_pred_naive, fs=step_rate)

    n_steps = data_gen.time_steps(step_length)
    start_ind = int(round(data_gen.time_range[0] / step_length, 6))
    trial_steps = int(round(event_trial_length / step_length, 6))
    stim_true = bci_utils.event_to_stim_channel(events, n_steps, trial_length=trial_steps, start_ind=start_ind)
    stim_pred = bci_utils.event_to_stim_channel(events_pred, n_steps, start_ind=start_ind)
    stim_pred_naive = bci_utils.event_to_stim_channel(events_pred_naive, n_steps, start_ind=start_ind)
    accu_hmm = accuracy_score(stim_true, stim_pred)
    accu_naive = accuracy_score(stim_true, stim_pred_naive)

    time_range = data_gen.time_range
    fig_hmm = _plot_decision_trace(time_range, stim_true, stim_pred, probs,
                                   model_hmm.n_classes, 'With HMM')
    fig_naive = _plot_decision_trace(time_range, stim_true, stim_pred_naive, probs_naive,
                                     model_hmm.n_classes, 'Naive', sharey=True)
    return (fig_hmm, fig_naive), (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)


def _plot_decision_trace(time_range, stim_true, stim_pred, probs, n_classes, title, sharey=False):
    """Plot true states, predicted states, then one class-probability trace per row; return the Figure."""
    fig, axes = plt.subplots(n_classes + 2, 1, sharex=True, sharey=sharey, figsize=(10, 8))
    axes[0].set_title('True states')
    bci_viz.plot_states(time_range, stim_true, ax=axes[0])
    axes[1].set_title('State sequence')
    bci_viz.plot_states(time_range, stim_pred, ax=axes[1])
    for i, ax in enumerate(axes[2:]):
        bci_viz.plot_state_prob_with_cue(time_range, stim_true, probs[:, i], ax=ax)
    fig.suptitle(title)
    return fig
def simulation(raw_val, event_id, model,
               epoch_length=1.,
               step_length=0.1,
               event_trial_length=5.):
    """Validate a trained model by replaying a recording as a simulated online stream.

    True events are read from the annotations of ``raw_val``. They are compared
    with the model's decisions twice: once with HMM smoothing and once with the
    naive per-step argmax.

    Args:
        raw_val (mne.io.Raw): recording to replay; its annotations give the true events.
        event_id (dict): annotation-to-event-code mapping passed to
            ``mne.events_from_annotations``.
        model: loaded online model to validate.
        epoch_length (float): sliding-window (buffer) length in seconds, default 1.
        step_length (float): hop between successive decisions in seconds, default 0.1.
        event_trial_length (float): duration in seconds of each true event, default 5.

    Returns:
        tuple: ``(metric_hmm, metric_naive, figs)``. Each metric is
        ``(precision, recall, f1, accuracy)``, and ``figs`` is ``(fig_hmm, fig_naive)``.
    """
    events_val, _ = mne.events_from_annotations(raw_val, event_id)
    # Run the same replay once, collecting both HMM and naive decisions.
    fig_pred, metric_hmm, metric_naive = _evaluation_loop(
        raw_val,
        events_val,
        model,
        epoch_length,
        step_length,
        event_trial_length=event_trial_length,
    )
    return metric_hmm, metric_naive, fig_pred
- def _construct_model_event(decision_seq, fs, start_cond=0):
- def _filter_seq(decision_seq):
- new_seq = [(decision_seq[0][0], start_cond)]
- for i in range(1, len(decision_seq)):
- if decision_seq[i][1] == -1:
- new_seq.append((decision_seq[i][0], new_seq[-1][1]))
- else:
- new_seq.append(decision_seq[i])
- return new_seq
- decision_seq = _filter_seq(decision_seq)
- last_state = decision_seq[0][1]
- events = [(int(decision_seq[0][0] * fs), 0, last_state)]
- for i in range(1, len(decision_seq)):
- time, label = decision_seq[i]
- if label != last_state:
- last_state = label
- events.append([int(time * fs), 0, label])
- return np.array(events)
def main():
    """Run the online simulation for one subject, save the decision plots and log metrics."""
    args = parse_args()
    if args.subj is None or args.model_filename is None:
        # Both are joined into filesystem paths below; os.path.join(None) fails obscurely.
        raise SystemExit('--subj and --model-filename are required')
    subj_name = args.subj
    data_dir = os.path.join(settings.DATA_PATH, subj_name)
    model_path = os.path.join(settings.MODEL_PATH, subj_name, args.model_filename)
    # Explicit encoding: the platform default is locale-dependent.
    with open(os.path.join(data_dir, 'val_info.yml'), 'r', encoding='utf-8') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']

    # Preprocess raw data.
    trial_time = 5.  # seconds; used as the loader's epoch length and as the true-event duration
    step_length = 0.1  # seconds between simulated online decisions
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   reref_method=config_info['reref'],
                                   ori_epoch_length=trial_time,
                                   upsampled_epoch_length=None)

    # Load the model with the HMM parameters given on the command line.
    input_kwargs = {
        'state_trans_prob': args.state_trans_prob,
        'state_change_threshold': args.state_change_threshold,
        'momentum': args.momentum,
    }
    model_hmm = online.model_loader(model_path, **input_kwargs)

    # Run the online simulation.
    metric_hmm, metric_naive, fig_pred = simulation(raw,
                                                    event_id,
                                                    model=model_hmm,
                                                    epoch_length=config_info['buffer_length'],
                                                    step_length=step_length,
                                                    event_trial_length=trial_time)
    fig_pred[0].savefig(os.path.join(data_dir, 'pred_hmm.pdf'))
    fig_pred[1].savefig(os.path.join(data_dir, 'pred_naive.pdf'))
    logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')
    logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')
    plt.show()


if __name__ == '__main__':
    main()
|