validation.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
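'''
Simulated online test script for a trained model.
The data are split into two folds: one fold trains the model, and the other
is tested in simulated online mode (decision AUC + event F1-score).
'''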
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import mne
  8. import yaml
  9. import os
  10. import joblib
  11. from scipy import stats
  12. from dataloaders import neo
  13. import bci_core.online as online
  14. import bci_core.utils as bci_utils
  15. import bci_core.viz as bci_viz
  16. class DataGenerator:
  17. def __init__(self, fs, X):
  18. self.fs = int(fs)
  19. self.X = X
  20. def get_data_batch(self, current_index):
  21. # return 1s batch
  22. # create mne object
  23. data = self.X[:, current_index - self.fs:current_index].copy()
  24. return self.fs, [], data
  25. def loop(self, step_size=0.1):
  26. step = int(step_size * self.fs)
  27. for i in range(self.fs, self.X.shape[1] + 1, step):
  28. yield i / self.fs, self.get_data_batch(i)
  29. def validation(raw_val, model_type, event_id, model, state_change_threshold=0.8):
  30. """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
  31. Args:
  32. raw (mne.io.Raw)
  33. model_type (string): type of model to train, baseline or riemann
  34. event_id (dict)
  35. model: validate existing model,
  36. state_change_threshold (float): default 0.8
  37. Returns:
  38. None
  39. """
  40. fs = raw_val.info['sfreq']
  41. events_val, _ = mne.events_from_annotations(raw_val, event_id)
  42. # plot ersd map
  43. fig_erds = bci_viz.plot_ersd(raw_val.get_data(), events_val, fs, (0, 1), event_id, 0)
  44. events_val = neo.reconstruct_events(events_val,
  45. fs,
  46. finger_model=None,
  47. rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
  48. mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
  49. use_original_label=True)
  50. if model_type == 'baseline':
  51. hmm_model = online.BaselineHMM(model, state_change_threshold=state_change_threshold)
  52. else:
  53. raise NotImplementedError
  54. controller = online.Controller(0, None)
  55. controller.set_real_feedback_model(hmm_model)
  56. # validate with the second half
  57. val_data = raw_val.get_data()
  58. data_gen = DataGenerator(fs, val_data)
  59. rets = []
  60. for time, data in data_gen.loop():
  61. cls = controller.decision(data)
  62. rets.append((time, cls))
  63. events_pred = _construct_model_event(rets, fs)
  64. precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
  65. stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
  66. stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
  67. corr, p = stats.pearsonr(stim_pred, stim_true)
  68. fig_pred, ax = plt.subplots(1, 1)
  69. ax.plot(raw_val.times, stim_pred, label='pred')
  70. ax.plot(raw_val.times, stim_true, label='true')
  71. ax.legend()
  72. return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
  73. def _construct_model_event(decision_seq, fs):
  74. events = []
  75. for i in decision_seq:
  76. time, cls = i
  77. if cls >= 0:
  78. events.append([int(time * fs), 0, cls])
  79. return np.array(events)
  80. def _event_to_stim_channel(events, time_length):
  81. x = np.zeros(time_length)
  82. for i in range(0, len(events) - 1):
  83. x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
  84. return x
  85. if __name__ == '__main__':
  86. # TODO: argparse
  87. subj_name = 'ylj'
  88. model_type = 'baseline'
  89. # TODO: load subject config
  90. data_dir = f'./data/{subj_name}/val/'
  91. model_path = f'./static/models/{subj_name}/scis.pkl'
  92. info = yaml.safe_load(os.path.join(data_dir, 'info.yml'))
  93. sessions = info['sessions']
  94. event_id = {'rest': 0}
  95. for f in sessions.keys():
  96. event_id[f] = neo.FINGERMODEL_IDS[f]
  97. # preprocess raw
  98. raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False, ori_epoch_length=5)
  99. # load model
  100. model = joblib.load(model_path)
  101. model_type, events = bci_utils.parse_model_type(model_path)
  102. metrics, fig_erds, fig_pred = validation(raw,
  103. model_type,
  104. event_id,
  105. model=model,
  106. state_change_threshold=0.8)
  107. fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
  108. fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
  109. print(metrics)