2
0

validation.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
'''
Script that tests a model by simulating online operation.
Online-mode test: event F1-score and decision trace.
'''
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import mne
  8. import yaml
  9. import os
  10. import argparse
  11. import logging
  12. from scipy import stats
  13. from dataloaders import neo
  14. import bci_core.online as online
  15. import bci_core.utils as bci_utils
  16. import bci_core.viz as bci_viz
  17. from settings.config import settings
# Set up root logging at DEBUG level. This is a module-level statement, so it
# runs at import time, not only when the file is executed as a script.
logging.basicConfig(level=logging.DEBUG)
# Project-wide configuration mapping; the script entry point reads
# config_info['buffer_length'] from it as the validation window length.
config_info = settings.CONFIG_INFO
  20. def parse_args():
  21. parser = argparse.ArgumentParser(
  22. description='Model validation'
  23. )
  24. parser.add_argument(
  25. '--subj',
  26. dest='subj',
  27. help='Subject name',
  28. default=None,
  29. type=str
  30. )
  31. parser.add_argument(
  32. '--state-change-threshold',
  33. '-scth',
  34. dest='state_change_threshold',
  35. help='Threshold for HMM state change',
  36. default=0.75,
  37. type=float
  38. )
  39. parser.add_argument(
  40. '--model-filename',
  41. dest='model_filename',
  42. help='Model filename',
  43. default=None,
  44. type=str
  45. )
  46. return parser.parse_args()
  47. class DataGenerator:
  48. def __init__(self, fs, X, epoch_step=1.):
  49. self.fs = int(fs)
  50. self.X = X
  51. self.epoch_step = epoch_step
  52. def get_data_batch(self, current_index):
  53. # return epoch_step length batch
  54. # create mne object
  55. ind = int(self.epoch_step * self.fs)
  56. data = self.X[:, current_index - ind:current_index].copy()
  57. return self.fs, [], data
  58. def loop(self, step_size=0.1):
  59. step = int(step_size * self.fs)
  60. for i in range(self.fs, self.X.shape[1] + 1, step):
  61. yield i / self.fs, self.get_data_batch(i)
  62. def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
  63. """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
  64. Args:
  65. raw (mne.io.Raw)
  66. event_id (dict)
  67. model: validate existing model,
  68. state_change_threshold (float): default 0.8
  69. step_length (float): batch data step length, default 1. (s)
  70. Returns:
  71. None
  72. """
  73. fs = raw_val.info['sfreq']
  74. events_val, _ = mne.events_from_annotations(raw_val, event_id)
  75. # plot ersd map
  76. fig_erds = bci_viz.plot_ersd(raw_val.get_data(), events_val, fs, (0, 1), event_id, 0)
  77. events_val = neo.reconstruct_events(events_val,
  78. fs,
  79. finger_model=None,
  80. rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
  81. mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
  82. use_original_label=True)
  83. controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
  84. # validate with the second half
  85. val_data = raw_val.get_data()
  86. data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
  87. decisions = []
  88. probs = []
  89. for time, data in data_gen.loop():
  90. cls = controller.decision(data)
  91. decisions.append((time, cls))
  92. probs.append((time, controller.real_feedback_model.probability))
  93. probs = np.array(probs)
  94. events_pred = _construct_model_event(decisions, fs)
  95. precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
  96. stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
  97. stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
  98. corr, _ = stats.pearsonr(stim_pred, stim_true)
  99. fig_pred, ax = plt.subplots(3, 1, sharex=True, sharey=False)
  100. ax[0].set_title('pred')
  101. ax[0].plot(raw_val.times, stim_pred)
  102. ax[1].set_title('true')
  103. ax[1].plot(raw_val.times, stim_true)
  104. ax[2].set_title('prob')
  105. ax[2].plot(probs[:, 0], probs[:, 1])
  106. ax[2].set_ylim([0, 1])
  107. return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
  108. def _construct_model_event(decision_seq, fs):
  109. events = []
  110. for i in decision_seq:
  111. time, cls = i
  112. if cls >= 0:
  113. events.append([int(time * fs), 0, cls])
  114. return np.array(events)
  115. def _event_to_stim_channel(events, time_length):
  116. x = np.zeros(time_length)
  117. for i in range(0, len(events) - 1):
  118. x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
  119. return x
  120. if __name__ == '__main__':
  121. args = parse_args()
  122. subj_name = args.subj
  123. # TODO: load subject config
  124. data_dir = f'./data/{subj_name}/'
  125. model_path = f'./static/models/{subj_name}/{args.model_filename}'
  126. with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
  127. info = yaml.safe_load(f)
  128. sessions = info['sessions']
  129. event_id = {'rest': 0}
  130. for f in sessions.keys():
  131. event_id[f] = neo.FINGERMODEL_IDS[f]
  132. # preprocess raw
  133. raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
  134. # do validations
  135. metrics, fig_erds, fig_pred = validation(raw,
  136. event_id,
  137. model=model_path,
  138. state_change_threshold=args.state_change_threshold,
  139. step_length=config_info['buffer_length'])
  140. fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
  141. fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
  142. logging.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')