# validation.py
'''
Model validation script.
Evaluates the AUC and plots the confusion matrix and the ERSD map.
'''
import os
import argparse
import logging

import mne
import yaml
import joblib
import numpy as np
from scipy import signal
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt

from dataloaders import neo
import bci_core.utils as bci_utils
import bci_core.viz as bci_viz
from settings.config import settings
  20. logging.basicConfig(level=logging.INFO)
  21. logger = logging.getLogger(__name__)
  22. config_info = settings.CONFIG_INFO
  23. def parse_args():
  24. parser = argparse.ArgumentParser(
  25. description='Model validation'
  26. )
  27. parser.add_argument(
  28. '--subj',
  29. dest='subj',
  30. help='Subject name',
  31. default=None,
  32. type=str
  33. )
  34. parser.add_argument(
  35. '--model-filename',
  36. dest='model_filename',
  37. help='Model filename',
  38. default=None,
  39. type=str
  40. )
  41. return parser.parse_args()
  42. def val_by_epochs(raw, model_path, event_id, trial_duration=1., ):
  43. events, _ = mne.events_from_annotations(raw, event_id=event_id)
  44. # parse model type
  45. model_type, _ = bci_utils.parse_model_type(model_path)
  46. if model_type == 'baseline':
  47. prob, y_pred = _val_by_epochs_baseline(raw, events, model_path, trial_duration)
  48. elif model_type == 'riemann':
  49. prob, y_pred = _val_by_epochs_riemann(raw, events, model_path, trial_duration)
  50. else:
  51. raise ValueError('Unaccepted model type')
  52. # metrices: AUC, accuracy,
  53. y = events[:, -1]
  54. auc = bci_utils.multiclass_auc_score(y, prob)
  55. accu = accuracy_score(y, y_pred)
  56. f1 = f1_score(y, y_pred, pos_label=np.max(y), average='macro')
  57. # confusion matrix
  58. fig_conf = bci_viz.plot_confusion_matrix(y, y_pred)
  59. return (auc, accu, f1), fig_conf
  60. def _val_by_epochs_baseline(raw, events, model_path, duration):
  61. fs = raw.info['sfreq']
  62. feat_extractor, baseline_model = joblib.load(model_path)
  63. filter_bank_data = feat_extractor.transform(raw.get_data())
  64. filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])
  65. # downsampling to 10 Hz
  66. # decim 2 times, to 100Hz
  67. decimate_rate = np.sqrt(fs / 10).astype(np.int16)
  68. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  69. # to 10Hz
  70. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  71. X = filter_bank_epoch
  72. # pred
  73. prob = baseline_model.predict_proba(X)
  74. y_pred = baseline_model.classes_[np.argmax(prob, axis=1)]
  75. return prob, y_pred
  76. def _val_by_epochs_riemann(raw, events, model_path, duration):
  77. fs = raw.info['sfreq']
  78. feat_extractor, scaler, cov_model, riemann_model = joblib.load(model_path)
  79. filtered_data = feat_extractor.transform(raw.get_data())
  80. X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
  81. X = scaler.transform(X)
  82. X_cov = cov_model.transform(X)
  83. # pred
  84. prob = riemann_model.predict_proba(X_cov)
  85. y_pred = riemann_model.classes_[np.argmax(prob, axis=1)]
  86. return prob, y_pred
  87. if __name__ == '__main__':
  88. args = parse_args()
  89. subj_name = args.subj
  90. data_dir = f'./data/{subj_name}/'
  91. model_path = f'./static/models/{subj_name}/{args.model_filename}'
  92. with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
  93. info = yaml.safe_load(f)
  94. sessions = info['sessions']
  95. # preprocess raw
  96. trial_time = 5.
  97. upsampled_trial_duration = config_info['buffer_length']
  98. raw, event_id = neo.raw_loader(data_dir, sessions,
  99. ori_epoch_length=trial_time,
  100. upsampled_epoch_length=upsampled_trial_duration)
  101. fs = raw.info['sfreq']
  102. events, _ = mne.events_from_annotations(raw, event_id)
  103. # ersd map
  104. fig_erds = bci_viz.plot_ersd(raw.get_data(), events, fs, (0, upsampled_trial_duration), event_id, 0)
  105. # Do validations
  106. metrices, fig_conf = val_by_epochs(raw, model_path, event_id, upsampled_trial_duration)
  107. # log results
  108. logger.info(f'Validation metrices: AUC: {metrices[0]:.4f}, Accuracy: {metrices[1]:.4f}, f1-score: {metrices[2]:.4f}')
  109. fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
  110. fig_conf.savefig(os.path.join(data_dir, 'confusion_matrix.pdf'))
  111. plt.show()