validation.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
'''
Model validation script.

Evaluates a trained model (AUC, accuracy and F1 score) and plots the
confusion matrix and the ERSD map.
'''
  6. import numpy as np
  7. import joblib
  8. import mne
  9. import yaml
  10. import os
  11. import argparse
  12. import logging
  13. from scipy import signal
  14. from sklearn.metrics import accuracy_score, f1_score
  15. from dataloaders import neo
  16. import bci_core.utils as bci_utils
  17. import bci_core.viz as bci_viz
  18. from settings.config import settings
  19. logging.basicConfig(level=logging.INFO)
  20. logger = logging.getLogger(__name__)
  21. config_info = settings.CONFIG_INFO
  22. def parse_args():
  23. parser = argparse.ArgumentParser(
  24. description='Model validation'
  25. )
  26. parser.add_argument(
  27. '--subj',
  28. dest='subj',
  29. help='Subject name',
  30. default=None,
  31. type=str
  32. )
  33. parser.add_argument(
  34. '--model-filename',
  35. dest='model_filename',
  36. help='Model filename',
  37. default=None,
  38. type=str
  39. )
  40. return parser.parse_args()
  41. def val_by_epochs(raw, model_path, event_id, trial_duration=1., ):
  42. events, _ = mne.events_from_annotations(raw, event_id=event_id)
  43. # parse model type
  44. model_type, _ = bci_utils.parse_model_type(model_path)
  45. if model_type == 'baseline':
  46. prob, y_pred = _val_by_epochs_baseline(raw, events, model_path, trial_duration)
  47. elif model_type == 'riemann':
  48. prob, y_pred = _val_by_epochs_riemann(raw, events, model_path, trial_duration)
  49. else:
  50. raise ValueError('Unaccepted model type')
  51. # metrices: AUC, accuracy,
  52. y = events[:, -1]
  53. auc = bci_utils.multiclass_auc_score(y, prob)
  54. accu = accuracy_score(y, y_pred)
  55. f1 = f1_score(y, y_pred, pos_label=np.max(y))
  56. # confusion matrix
  57. fig_conf = bci_viz.plot_confusion_matrix(y, y_pred)
  58. return (auc, accu, f1), fig_conf
  59. def _val_by_epochs_baseline(raw, events, model_path, duration):
  60. fs = raw.info['sfreq']
  61. feat_extractor, baseline_model = joblib.load(model_path)
  62. filter_bank_data = feat_extractor.transform(raw.get_data())
  63. filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])
  64. # downsampling to 10 Hz
  65. # decim 2 times, to 100Hz
  66. decimate_rate = np.sqrt(fs / 10).astype(np.int16)
  67. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  68. # to 10Hz
  69. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  70. X = filter_bank_epoch
  71. # pred
  72. prob = baseline_model.predict_proba(X)
  73. y_pred = baseline_model.classes_[np.argmax(prob, axis=1)]
  74. return prob, y_pred
  75. def _val_by_epochs_riemann(raw, events, model_path, duration):
  76. fs = raw.info['sfreq']
  77. feat_extractor, scaler, cov_model, riemann_model = joblib.load(model_path)
  78. filtered_data = feat_extractor.transform(raw.get_data())
  79. X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
  80. X = scaler.transform(X)
  81. X_cov = cov_model.transform(X)
  82. # pred
  83. prob = riemann_model.predict_proba(X_cov)
  84. y_pred = riemann_model.classes_[np.argmax(prob, axis=1)]
  85. return prob, y_pred
  86. if __name__ == '__main__':
  87. args = parse_args()
  88. subj_name = args.subj
  89. data_dir = f'./data/{subj_name}/'
  90. model_path = f'./static/models/{subj_name}/{args.model_filename}'
  91. with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
  92. info = yaml.safe_load(f)
  93. sessions = info['sessions']
  94. # preprocess raw
  95. trial_time = 5.
  96. upsampled_trial_duration = config_info['buffer_length']
  97. raw, event_id = neo.raw_loader(data_dir, sessions,
  98. ori_epoch_length=trial_time,
  99. upsampled_epoch_length=upsampled_trial_duration)
  100. fs = raw.info['sfreq']
  101. events, _ = mne.events_from_annotations(raw, event_id)
  102. # ersd map
  103. fig_erds = bci_viz.plot_ersd(raw.get_data(), events, fs, (0, upsampled_trial_duration), event_id, 0)
  104. # Do validations
  105. metrices, fig_conf = val_by_epochs(raw, model_path, event_id, upsampled_trial_duration)
  106. # log results
  107. logger.info(f'Validation metrices: AUC: {metrices[0]:.4f}, Accuracy: {metrices[1]:.4f}, f1-score: {metrices[2]:.4f}')
  108. fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
  109. fig_conf.savefig(os.path.join(data_dir, 'confusion_matrix.pdf'))