123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
'''
Model validation script:
- evaluates AUC,
- plots the confusion matrix and ERSD map.
'''
- import os
- import argparse
- import logging
- import mne
- import yaml
- import joblib
- import numpy as np
- from scipy import signal
- from sklearn.metrics import accuracy_score, f1_score
- import matplotlib.pyplot as plt
- from dataloaders import neo
- import bci_core.utils as bci_utils
- import bci_core.pipeline as bci_pipeline
- import bci_core.viz as bci_viz
- from settings.config import settings
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
- config_info = settings.CONFIG_INFO
def parse_args(argv=None):
    """Parse command-line arguments for model validation.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case ``sys.argv[1:]`` is parsed (original behavior);
            exposed as a parameter so the parser can be driven in tests.

    Returns:
        argparse.Namespace with ``subj`` and ``model_filename`` attributes
        (both default to ``None``).
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        default=None,
        type=str
    )
    return parser.parse_args(argv)
def val_by_epochs(raw, model_path, event_id, trial_duration=1.):
    """Evaluate a serialized model on epochs extracted from a recording.

    Args:
        raw: mne Raw object whose annotations mark trial onsets.
        model_path: Path to a joblib-serialized model (or model collection).
        event_id: Mapping of annotation description -> integer event code.
        trial_duration: Epoch length in seconds used for evaluation.

    Returns:
        Tuple ``((auc, accuracy, f1), fig_conf)`` where the first element
        holds the macro metrics and ``fig_conf`` is the confusion-matrix
        figure.
    """
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    # Load the trained model(s) and score every epoch.
    models = joblib.load(model_path)
    prob, y_pred = bci_pipeline.data_evaluation(models, raw.get_data(), raw.info['sfreq'], events, trial_duration, True)

    # Metrics: multiclass AUC, accuracy, macro-averaged f1.
    y = events[:, -1]
    auc = bci_utils.multiclass_auc_score(y, prob)
    accu = accuracy_score(y, y_pred)
    # pos_label is ignored (and sklearn warns about it) whenever
    # average != 'binary', so it is intentionally not passed here.
    f1 = f1_score(y, y_pred, average='macro')
    # Confusion matrix figure for visual inspection.
    fig_conf = bci_viz.plot_confusion_matrix(y, y_pred)
    return (auc, accu, f1), fig_conf
- if __name__ == '__main__':
- args = parse_args()
- subj_name = args.subj
- data_dir = os.path.join(settings.DATA_PATH, subj_name)
- model_path = os.path.join(settings.MODEL_PATH, subj_name, args.model_filename)
-
- with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
- info = yaml.safe_load(f)
- sessions = info['sessions']
-
- ori_epoch_length = info.get('ori_epoch_length', 5.)
- upsampled_trial_duration = config_info['buffer_length']
- # preprocess raw
- raw, event_id = neo.raw_loader(data_dir, sessions,
- ori_epoch_length=ori_epoch_length,
- reref_method=config_info['reref'],
- upsampled_epoch_length=upsampled_trial_duration)
-
- fs = raw.info['sfreq']
-
- events, _ = mne.events_from_annotations(raw, event_id)
- # Do validations
- metrices, fig_conf = val_by_epochs(raw, model_path, event_id, upsampled_trial_duration)
- # log results
- logger.info(f'Validation metrices: AUC: {metrices[0]:.4f}, Accuracy: {metrices[1]:.4f}, f1-score: {metrices[2]:.4f}')
-
- fig_conf.savefig(os.path.join(data_dir, 'confusion_matrix.pdf'))
- plt.show()
|