1
0

validation.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
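'''
模型模拟在线测试脚本
Simulated online-test script for a model.
数据两折分割,1折训练模型,1折按照在线模式测试:decision AUC + event f1-score
(The data are split into two folds: one fold trains the model, the other fold
is replayed in simulated online mode and scored with decision AUC + event f1-score.)
'''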
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import mne
  8. import yaml
  9. import os
  10. import logging
  11. from scipy import stats
  12. from dataloaders import neo
  13. import bci_core.online as online
  14. import bci_core.utils as bci_utils
  15. import bci_core.viz as bci_viz
# Configure the root logger at DEBUG level. This call sits at module level, so
# it runs at import time as well as when the file is executed as a script.
logging.basicConfig(level=logging.DEBUG)
  17. class DataGenerator:
  18. def __init__(self, fs, X):
  19. self.fs = int(fs)
  20. self.X = X
  21. def get_data_batch(self, current_index):
  22. # return 1s batch
  23. # create mne object
  24. data = self.X[:, current_index - self.fs:current_index].copy()
  25. return self.fs, [], data
  26. def loop(self, step_size=0.1):
  27. step = int(step_size * self.fs)
  28. for i in range(self.fs, self.X.shape[1] + 1, step):
  29. yield i / self.fs, self.get_data_batch(i)
  30. def validation(raw_val, event_id, model, state_change_threshold=0.8):
  31. """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
  32. Args:
  33. raw (mne.io.Raw)
  34. event_id (dict)
  35. model: validate existing model,
  36. state_change_threshold (float): default 0.8
  37. Returns:
  38. None
  39. """
  40. fs = raw_val.info['sfreq']
  41. events_val, _ = mne.events_from_annotations(raw_val, event_id)
  42. # plot ersd map
  43. fig_erds = bci_viz.plot_ersd(raw_val.get_data(), events_val, fs, (0, 1), event_id, 0)
  44. events_val = neo.reconstruct_events(events_val,
  45. fs,
  46. finger_model=None,
  47. rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
  48. mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
  49. use_original_label=True)
  50. controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
  51. # validate with the second half
  52. val_data = raw_val.get_data()
  53. data_gen = DataGenerator(fs, val_data)
  54. rets = []
  55. for time, data in data_gen.loop():
  56. cls = controller.decision(data)
  57. rets.append((time, cls))
  58. events_pred = _construct_model_event(rets, fs)
  59. precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
  60. stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
  61. stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
  62. corr, p = stats.pearsonr(stim_pred, stim_true)
  63. fig_pred, ax = plt.subplots(1, 1)
  64. ax.plot(raw_val.times, stim_pred, label='pred')
  65. ax.plot(raw_val.times, stim_true, label='true')
  66. ax.legend()
  67. return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
  68. def _construct_model_event(decision_seq, fs):
  69. events = []
  70. for i in decision_seq:
  71. time, cls = i
  72. if cls >= 0:
  73. events.append([int(time * fs), 0, cls])
  74. return np.array(events)
  75. def _event_to_stim_channel(events, time_length):
  76. x = np.zeros(time_length)
  77. for i in range(0, len(events) - 1):
  78. x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
  79. return x
  80. if __name__ == '__main__':
  81. # TODO: argparse
  82. subj_name = 'ylj'
  83. # TODO: load subject config
  84. data_dir = f'./data/{subj_name}/'
  85. model_path = f'./static/models/{subj_name}/baseline_rest+cylinder_11-19-2023-17-31-18.pkl'
  86. with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
  87. info = yaml.safe_load(f)
  88. sessions = info['sessions']
  89. event_id = {'rest': 0}
  90. for f in sessions.keys():
  91. event_id[f] = neo.FINGERMODEL_IDS[f]
  92. # preprocess raw
  93. raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[6], rest_trial_ind=[7])
  94. # do validations
  95. metrics, fig_erds, fig_pred = validation(raw,
  96. event_id,
  97. model=model_path,
  98. state_change_threshold=0.8)
  99. fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
  100. fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
  101. logging.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')