1
0

validation.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
'''
模型模拟在线测试脚本
数据两折分割,1折训练模型,1折按照在线模式测试:decision AUC + event f1-score
'''
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import mne
  8. import yaml
  9. import os
  10. import joblib
  11. from scipy import stats
  12. from dataloaders import neo
  13. import training
  14. import bci_core.online as online
  15. import bci_core.utils as bci_utils
  16. import bci_core.viz as bci_viz
  17. class DataGenerator:
  18. def __init__(self, fs, X):
  19. self.fs = int(fs)
  20. self.X = X
  21. def get_data_batch(self, current_index):
  22. # return 1s batch
  23. # create mne object
  24. data = self.X[:, current_index - self.fs:current_index].copy()
  25. # append event channel
  26. data = np.concatenate((data, np.zeros((1, data.shape[1]))), axis=0)
  27. info = mne.create_info([f'S{i}' for i in range(len(data))], self.fs, ['ecog'] * (len(data) - 1) + ['misc'])
  28. raw = mne.io.RawArray(data, info, verbose=False)
  29. return {'data': raw}
  30. def loop(self, step_size=0.1):
  31. step = int(step_size * self.fs)
  32. for i in range(self.fs, self.X.shape[1] + 1, step):
  33. yield i / self.fs, self.get_data_batch(i)
  34. def validation(raw_val, model_type, event_id, model, state_change_threshold=0.8):
  35. """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
  36. Args:
  37. raw (mne.io.Raw)
  38. model_type (string): type of model to train, baseline or riemann
  39. event_id (dict)
  40. model: validate existing model,
  41. state_change_threshold (float): default 0.8
  42. Returns:
  43. None
  44. """
  45. fs = raw_val.info['sfreq']
  46. # plot ersd map
  47. events, _ = mne.events_from_annotations(raw_val, event_id)
  48. fig_erds = bci_viz.plot_ersd(raw_val.get_data(), events, fs, (0, 1), event_id, 0)
  49. events_val, _ = mne.events_from_annotations(raw_val, event_id)
  50. events_val = neo.reconstruct_events(events_val,
  51. fs,
  52. finger_model=None,
  53. rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
  54. mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
  55. use_original_label=True)
  56. if model_type == 'baseline':
  57. hmm_model = online.BaselineHMM(model, state_change_threshold=state_change_threshold)
  58. else:
  59. raise NotImplementedError
  60. controller = online.Controller(0, None)
  61. controller.set_real_feedback_model(hmm_model)
  62. # validate with the second half
  63. val_data = raw_val.get_data()
  64. data_gen = DataGenerator(fs, val_data)
  65. rets = []
  66. for time, data in data_gen.loop():
  67. cls = controller.decision(data)
  68. rets.append((time, cls))
  69. events_pred = _construct_model_event(rets, fs)
  70. precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
  71. stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
  72. stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
  73. corr, p = stats.pearsonr(stim_pred, stim_true)
  74. fig_pred, ax = plt.subplots(1, 1)
  75. ax.plot(raw_val.times, stim_pred, label='pred')
  76. ax.plot(raw_val.times, stim_true, label='true')
  77. ax.legend()
  78. return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
  79. def _construct_model_event(decision_seq, fs):
  80. events = []
  81. for i in decision_seq:
  82. time, cls = i
  83. if cls >= 0:
  84. events.append([int(time * fs), 0, cls])
  85. return np.array(events)
  86. def _event_to_stim_channel(events, time_length):
  87. x = np.zeros(time_length)
  88. for i in range(0, len(events) - 1):
  89. x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
  90. return x
  91. if __name__ == '__main__':
  92. subj_name = 'ylj'
  93. model_type = 'baseline'
  94. # TODO: load subject config
  95. data_dir = f'./data/{subj_name}/val/'
  96. model_path = f'./static/models/{subj_name}/scis.pkl'
  97. info = yaml.safe_load(os.path.join(data_dir, 'info.yml'))
  98. sessions = info['sessions']
  99. event_id = {'rest': 0}
  100. for f in sessions.keys():
  101. event_id[f] = neo.FINGERMODEL_IDS[f]
  102. # preprocess raw
  103. raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False)
  104. # load model
  105. model = joblib.load(model_path)
  106. model_type, events = bci_utils.parse_model_type(model_path)
  107. metrics, fig_erds, fig_pred = validation(raw,
  108. model_type,
  109. event_id,
  110. model=model,
  111. state_change_threshold=0.8)
  112. fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
  113. fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
  114. print(metrics)