validation.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. '''
  2. 模型模拟在线测试脚本
  3. 数据两折分割,1折训练模型,1折按照在线模式测试:decision AUC + event f1-score
  4. '''
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import mne
  8. import yaml
  9. import os
  10. import joblib
  11. from scipy import stats
  12. from dataloaders import neo
  13. import training
  14. import bci_core.online as online
  15. import bci_core.utils as bci_utils
  16. import bci_core.viz as bci_viz
  17. class DataGenerator:
  18. def __init__(self, fs, X):
  19. self.fs = int(fs)
  20. self.X = X
  21. def get_data_batch(self, current_index):
  22. # return 1s batch
  23. # create mne object
  24. data = self.X[:, current_index - self.fs:current_index].copy()
  25. # append event channel
  26. data = np.concatenate((data, np.zeros((1, data.shape[1]))), axis=0)
  27. info = mne.create_info([f'S{i}' for i in range(len(data))], self.fs, ['ecog'] * (len(data) - 1) + ['misc'])
  28. raw = mne.io.RawArray(data, info, verbose=False)
  29. return {'data': raw}
  30. def loop(self, step_size=0.1):
  31. step = int(step_size * self.fs)
  32. for i in range(self.fs, self.X.shape[1] + 1, step):
  33. yield i / self.fs, self.get_data_batch(i)
  34. def validation(raw_val, model_type, event_id, model, state_change_threshold=0.8):
  35. """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
  36. Args:
  37. raw (mne.io.Raw)
  38. model_type (string): type of model to train, baseline or riemann
  39. event_id (dict)
  40. model: validate existing model,
  41. state_change_threshold (float): default 0.8
  42. Returns:
  43. None
  44. """
  45. fs = raw_val.info['sfreq']
  46. events_val, _ = mne.events_from_annotations(raw_val, event_id)
  47. # plot ersd map
  48. fig_erds = bci_viz.plot_ersd(raw_val.get_data(), events_val, fs, (0, 1), event_id, 0)
  49. events_val = neo.reconstruct_events(events_val,
  50. fs,
  51. finger_model=None,
  52. rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
  53. mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
  54. use_original_label=True)
  55. if model_type == 'baseline':
  56. hmm_model = online.BaselineHMM(model, state_change_threshold=state_change_threshold)
  57. else:
  58. raise NotImplementedError
  59. controller = online.Controller(0, None)
  60. controller.set_real_feedback_model(hmm_model)
  61. # validate with the second half
  62. val_data = raw_val.get_data()
  63. data_gen = DataGenerator(fs, val_data)
  64. rets = []
  65. for time, data in data_gen.loop():
  66. cls = controller.decision(data)
  67. rets.append((time, cls))
  68. events_pred = _construct_model_event(rets, fs)
  69. precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
  70. stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
  71. stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
  72. corr, p = stats.pearsonr(stim_pred, stim_true)
  73. fig_pred, ax = plt.subplots(1, 1)
  74. ax.plot(raw_val.times, stim_pred, label='pred')
  75. ax.plot(raw_val.times, stim_true, label='true')
  76. ax.legend()
  77. return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
  78. def _construct_model_event(decision_seq, fs):
  79. events = []
  80. for i in decision_seq:
  81. time, cls = i
  82. if cls >= 0:
  83. events.append([int(time * fs), 0, cls])
  84. return np.array(events)
  85. def _event_to_stim_channel(events, time_length):
  86. x = np.zeros(time_length)
  87. for i in range(0, len(events) - 1):
  88. x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
  89. return x
  90. if __name__ == '__main__':
  91. subj_name = 'ylj'
  92. model_type = 'baseline'
  93. # TODO: load subject config
  94. data_dir = f'./data/{subj_name}/val/'
  95. model_path = f'./static/models/{subj_name}/scis.pkl'
  96. info = yaml.safe_load(os.path.join(data_dir, 'info.yml'))
  97. sessions = info['sessions']
  98. event_id = {'rest': 0}
  99. for f in sessions.keys():
  100. event_id[f] = neo.FINGERMODEL_IDS[f]
  101. # preprocess raw
  102. raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False)
  103. # load model
  104. model = joblib.load(model_path)
  105. model_type, events = bci_utils.parse_model_type(model_path)
  106. metrics, fig_erds, fig_pred = validation(raw,
  107. model_type,
  108. event_id,
  109. model=model,
  110. state_change_threshold=0.8)
  111. fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
  112. fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
  113. print(metrics)