2
0

validation.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. '''
  2. 模型模拟在线测试脚本
  3. 在线模式测试:event f1-score and decision trace
  4. '''
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import mne
  8. import yaml
  9. import os
  10. import logging
  11. from scipy import stats
  12. from dataloaders import neo
  13. import bci_core.online as online
  14. import bci_core.utils as bci_utils
  15. import bci_core.viz as bci_viz
  16. from settings.config import settings
  17. logging.basicConfig(level=logging.INFO)
  18. config_info = settings.CONFIG_INFO
  19. class DataGenerator:
  20. def __init__(self, fs, X, epoch_step=1.):
  21. self.fs = int(fs)
  22. self.X = X
  23. self.epoch_step = epoch_step
  24. def get_data_batch(self, current_index):
  25. # return epoch_step length batch
  26. # create mne object
  27. ind = int(self.epoch_step * self.fs)
  28. data = self.X[:, current_index - ind:current_index].copy()
  29. return self.fs, [], data
  30. def loop(self, step_size=0.1):
  31. step = int(step_size * self.fs)
  32. for i in range(self.fs, self.X.shape[1] + 1, step):
  33. yield i / self.fs, self.get_data_batch(i)
  34. def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
  35. """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
  36. Args:
  37. raw (mne.io.Raw)
  38. event_id (dict)
  39. model: validate existing model,
  40. state_change_threshold (float): default 0.8
  41. step_length (float): batch data step length, default 1. (s)
  42. Returns:
  43. None
  44. """
  45. fs = raw_val.info['sfreq']
  46. events_val, _ = mne.events_from_annotations(raw_val, event_id)
  47. # plot ersd map
  48. fig_erds = bci_viz.plot_ersd(raw_val.get_data(), events_val, fs, (0, 1), event_id, 0)
  49. events_val = neo.reconstruct_events(events_val,
  50. fs,
  51. finger_model=None,
  52. rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
  53. mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
  54. use_original_label=True)
  55. controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
  56. # validate with the second half
  57. val_data = raw_val.get_data()
  58. data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
  59. rets = []
  60. for time, data in data_gen.loop():
  61. cls = controller.decision(data)
  62. rets.append((time, cls))
  63. events_pred = _construct_model_event(rets, fs)
  64. precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
  65. stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
  66. stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
  67. corr, p = stats.pearsonr(stim_pred, stim_true)
  68. fig_pred, ax = plt.subplots(1, 1)
  69. ax.plot(raw_val.times, stim_pred, label='pred')
  70. ax.plot(raw_val.times, stim_true, label='true')
  71. ax.legend()
  72. return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
  73. def _construct_model_event(decision_seq, fs):
  74. events = []
  75. for i in decision_seq:
  76. time, cls = i
  77. if cls >= 0:
  78. events.append([int(time * fs), 0, cls])
  79. return np.array(events)
  80. def _event_to_stim_channel(events, time_length):
  81. x = np.zeros(time_length)
  82. for i in range(0, len(events) - 1):
  83. x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
  84. return x
  85. if __name__ == '__main__':
  86. # TODO: argparse
  87. subj_name = 'XW01'
  88. # TODO: load subject config
  89. data_dir = f'./data/{subj_name}/'
  90. model_path = f'./static/models/{subj_name}/riemann_rest+flex_11-21-2023-16-43-23.pkl'
  91. with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
  92. info = yaml.safe_load(f)
  93. sessions = info['sessions']
  94. event_id = {'rest': 0}
  95. for f in sessions.keys():
  96. event_id[f] = neo.FINGERMODEL_IDS[f]
  97. # preprocess raw
  98. raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
  99. # do validations
  100. metrics, fig_erds, fig_pred = validation(raw,
  101. event_id,
  102. model=model_path,
  103. state_change_threshold=0.75,
  104. step_length=config_info['buffer_length'])
  105. fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
  106. fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
  107. logging.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')