2
0

online_sim.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
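'''
Offline simulation of online model testing (模型模拟在线测试脚本).

Online-mode evaluation: event-level precision / recall / F1-score and
decision traces, computed with and without HMM smoothing.
'''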
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import mne
  8. import yaml
  9. import os
  10. import argparse
  11. import logging
  12. from sklearn.metrics import accuracy_score
  13. from dataloaders import neo
  14. import bci_core.online as online
  15. import bci_core.utils as bci_utils
  16. import bci_core.viz as bci_viz
  17. from settings.config import settings
  18. logging.basicConfig(level=logging.DEBUG)
  19. logger = logging.getLogger(__name__)
  20. config_info = settings.CONFIG_INFO
  21. def parse_args():
  22. parser = argparse.ArgumentParser(
  23. description='Model validation'
  24. )
  25. parser.add_argument(
  26. '--subj',
  27. dest='subj',
  28. help='Subject name',
  29. default=None,
  30. type=str
  31. )
  32. parser.add_argument(
  33. '--state-change-threshold',
  34. '-scth',
  35. dest='state_change_threshold',
  36. help='Threshold for HMM state change',
  37. default=0.75,
  38. type=float
  39. )
  40. parser.add_argument(
  41. '--state-trans-prob',
  42. '-stp',
  43. dest='state_trans_prob',
  44. help='Transition probability for HMM state change',
  45. default=0.8,
  46. type=float
  47. )
  48. parser.add_argument(
  49. '--momentum',
  50. help='Probability update momentum',
  51. default=0.5,
  52. type=float
  53. )
  54. parser.add_argument(
  55. '--model-filename',
  56. dest='model_filename',
  57. help='Model filename',
  58. default=None,
  59. type=str
  60. )
  61. return parser.parse_args()
  62. class DataGenerator:
  63. def __init__(self, fs, X, epoch_step=1.):
  64. self.fs = int(fs)
  65. self.X = X
  66. self.epoch_step = epoch_step
  67. def get_data_batch(self, current_index):
  68. # return epoch_step length batch
  69. # create mne object
  70. ind = int(self.epoch_step * self.fs)
  71. data = self.X[:, current_index - ind:current_index].copy()
  72. return self.fs, [], data
  73. def loop(self, step_size=0.1):
  74. step = int(step_size * self.fs)
  75. for i in range(int(self.epoch_step * self.fs), self.X.shape[1] + 1, step):
  76. yield i / self.fs, self.get_data_batch(i)
  77. @property
  78. def time_range(self):
  79. return self.epoch_step, self.X.shape[1] / self.fs
  80. def time_steps(self, step_size=0.1):
  81. step = int(step_size * self.fs)
  82. return len(list(range(int(self.epoch_step * self.fs), self.X.shape[1] + 1, step)))
  83. def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_trial_length):
  84. val_data = raw.get_data()
  85. fs = raw.info['sfreq']
  86. data_gen = DataGenerator(fs, val_data, epoch_step=epoch_length)
  87. # events -> 1 / step_length
  88. events[:, 0] = (events[:, 0] / fs / step_length).astype(np.int32)
  89. decision_with_hmm = []
  90. decision_without_hmm = []
  91. probs = []
  92. probs_naive = []
  93. prob_timestamps = []
  94. skip_flag = False
  95. for time, (fs, event, data) in data_gen.loop(step_length):
  96. # skip 3 seconds if skip flag is True
  97. if skip_flag and (time - tic) < 3:
  98. continue
  99. else:
  100. skip_flag = False
  101. step_p, cls = model_hmm.viterbi(fs, data, return_step_p=True)
  102. if cls >= 0:
  103. cls = model_hmm.model.classes_[cls]
  104. decision_with_hmm.append((time, cls)) # map to unified label
  105. decision_without_hmm.append((time, model_hmm.model.classes_[np.argmax(step_p)]))
  106. prob_timestamps.append(time)
  107. probs.append(model_hmm.probability)
  108. probs_naive.append(step_p)
  109. # start timer when cls == 0
  110. if cls == 0 and not skip_flag:
  111. skip_flag = True
  112. tic = time
  113. probs = np.array(probs)
  114. probs_naive = np.array(probs_naive)
  115. prob_timestamps = np.array(prob_timestamps)
  116. events_pred = _construct_model_event(decision_with_hmm, 1 / step_length)
  117. events_pred_naive = _construct_model_event(decision_without_hmm, 1 / step_length)
  118. p_hmm, r_hmm, f1_hmm, latency_hmm = bci_utils.event_metric(event_true=events, event_pred=events_pred, fs=1 / step_length)
  119. p_n, r_n, f1_n, _ = bci_utils.event_metric(events, events_pred_naive, fs=1 / step_length)
  120. time_steps = data_gen.time_steps(step_length)
  121. start_ind = int(data_gen.time_range[0] / step_length)
  122. if event_trial_length == 'varied':
  123. event_trial_length = None
  124. elif isinstance(event_trial_length, dict):
  125. event_trial_length = {k: int(v[-1] / step_length) for k, v in event_trial_length.items()}
  126. else:
  127. event_trial_length = int(event_trial_length / step_length)
  128. stim_true = bci_utils.event_to_stim_channel(events, time_steps, trial_length=event_trial_length, start_ind=start_ind)
  129. stim_pred = bci_utils.event_to_stim_channel(events_pred, time_steps, start_ind=start_ind)
  130. stim_pred_naive = bci_utils.event_to_stim_channel(events_pred_naive, time_steps, start_ind=start_ind)
  131. accu_hmm = accuracy_score(stim_true, stim_pred)
  132. accu_naive = accuracy_score(stim_true, stim_pred_naive)
  133. # hmm plot
  134. fig_hmm, axes = plt.subplots(model_hmm.n_classes + 2, 1, sharex=True, figsize=(10, 8))
  135. axes[0].set_title('True states')
  136. bci_viz.plot_states(data_gen.time_range, stim_true, ax=axes[0])
  137. axes[1].set_title('State sequence')
  138. bci_viz.plot_states(data_gen.time_range, stim_pred, ax=axes[1])
  139. for i, ax in enumerate(axes[2:]):
  140. bci_viz.plot_state_prob_with_cue(true_states=stim_true,
  141. pred_probs=probs[:, i],
  142. times=prob_timestamps, ax=ax)
  143. fig_hmm.suptitle('With HMM')
  144. # naive plot
  145. fig_naive, axes = plt.subplots(model_hmm.n_classes + 2, 1, sharex=True, sharey=True, figsize=(10, 8))
  146. axes[0].set_title('True states')
  147. bci_viz.plot_states(data_gen.time_range, stim_true, ax=axes[0])
  148. axes[1].set_title('State sequence')
  149. bci_viz.plot_states(data_gen.time_range, stim_pred_naive, ax=axes[1])
  150. for i, ax in enumerate(axes[2:]):
  151. bci_viz.plot_state_prob_with_cue(true_states=stim_true,
  152. pred_probs=probs_naive[:, i],
  153. times=prob_timestamps, ax=ax)
  154. fig_naive.suptitle('Naive')
  155. return (fig_hmm, fig_naive), (p_hmm, r_hmm, f1_hmm, accu_hmm, latency_hmm), (p_n, r_n, f1_n, accu_naive)
  156. def simulation(raw_val, event_id, model,
  157. epoch_length=1.,
  158. step_length=0.1,
  159. event_trial_length=5.):
  160. """模型验证接口,使用指定数据进行验证,绘制ersd map
  161. Args:
  162. raw (mne.io.Raw)
  163. event_id (dict)
  164. model: validate existing model,
  165. epoch_length (float): batch data length, default 1 (s)
  166. step_length (float): data step length, default 0.1 (s)
  167. event_trial_length (float or dict or 'varied'):
  168. Returns:
  169. None
  170. """
  171. fs = raw_val.info['sfreq']
  172. events_val, _ = mne.events_from_annotations(raw_val, event_id)
  173. # run with and without hmm
  174. fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val,
  175. events_val,
  176. model,
  177. epoch_length,
  178. step_length,
  179. event_trial_length=event_trial_length)
  180. return metric_hmm, metric_naive, fig_pred
  181. def _construct_model_event(decision_seq, fs, start_cond=0):
  182. def _filter_seq(decision_seq):
  183. new_seq = [(decision_seq[0][0], start_cond)]
  184. for i in range(1, len(decision_seq)):
  185. if decision_seq[i][1] == -1:
  186. new_seq.append((decision_seq[i][0], new_seq[-1][1]))
  187. else:
  188. new_seq.append(decision_seq[i])
  189. return new_seq
  190. decision_seq = _filter_seq(decision_seq)
  191. last_state = decision_seq[0][1]
  192. events = [(int(decision_seq[0][0] * fs), 0, last_state)]
  193. for i in range(1, len(decision_seq)):
  194. time, label = decision_seq[i]
  195. if label != last_state:
  196. last_state = label
  197. events.append([int(time * fs), 0, label])
  198. return np.array(events)
  199. if __name__ == '__main__':
  200. args = parse_args()
  201. subj_name = args.subj
  202. data_dir = os.path.join(settings.DATA_PATH, subj_name)
  203. model_path = os.path.join(settings.MODEL_PATH, subj_name, args.model_filename)
  204. with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
  205. info = yaml.safe_load(f)
  206. sessions = info['sessions']
  207. # preprocess raw
  208. trial_time = info.get('ori_epoch_length', 'varied')
  209. raw, event_id = neo.raw_loader(data_dir, sessions,
  210. reref_method=config_info['reref'],
  211. ori_epoch_length=trial_time,
  212. upsampled_epoch_length=None)
  213. # load model
  214. input_kwargs = {
  215. 'state_trans_prob': args.state_trans_prob,
  216. 'state_change_threshold': args.state_change_threshold,
  217. 'momentum': args.momentum
  218. }
  219. model_hmm = online.model_loader(model_path, **input_kwargs)
  220. # do online simulation
  221. metric_hmm, metric_naive, fig_pred = simulation(raw,
  222. event_id,
  223. model=model_hmm,
  224. epoch_length=config_info['buffer_length'],
  225. step_length=0.1,
  226. event_trial_length=trial_time)
  227. fig_pred[0].savefig(os.path.join(data_dir, 'pred_hmm.pdf'))
  228. fig_pred[1].savefig(os.path.join(data_dir, 'pred_naive.pdf'))
  229. logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}, latency: {np.mean(metric_hmm[4]):.4f} +- {np.std(metric_hmm[4]):.4f}')
  230. logger.info(f'Without HMM: precision: {metric_naive[0]:.4f}, recall: {metric_naive[1]:.4f}, f1_score: {metric_naive[2]:.4f}, accuracy: {metric_naive[3]:.4f}')
  231. plt.show()