"""Model validation script.

Loads a trained model for one subject, evaluates it on that subject's
validation sessions, logs AUC / accuracy / macro F1, and saves a
confusion-matrix plot. (The original header also mentioned an ERSD map;
this script does not produce one.)
"""
import os
import argparse
import logging

import mne
import yaml
import joblib
import numpy as np
from scipy import signal
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt

from dataloaders import neo
import bci_core.utils as bci_utils
import bci_core.pipeline as bci_pipeline
import bci_core.viz as bci_viz
from settings.config import settings

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO


def parse_args():
    """Parse the command-line arguments.

    Both ``--subj`` and ``--model-filename`` are required: they are joined
    into filesystem paths, and ``os.path.join`` raises ``TypeError`` on
    ``None``, so argparse rejects a missing value up front with a usage
    message instead.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        required=True,
        type=str
    )
    return parser.parse_args()


def val_by_epochs(raw, model_path, event_id, trial_duration=1.):
    """Evaluate a saved model on the annotated trials of ``raw``.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose annotations mark the trials to evaluate.
    model_path : str
        Path to the model file saved with ``joblib``.
    event_id : dict
        Annotation-description -> integer-code mapping passed to
        ``mne.events_from_annotations``; the resulting codes are used as
        the ground-truth labels.
    trial_duration : float
        Trial length forwarded to ``bci_pipeline.data_evaluation``
        (presumably seconds -- confirm against that function).

    Returns
    -------
    tuple
        ``((auc, accuracy, macro_f1), fig_conf)``, where ``fig_conf`` is the
        figure returned by ``bci_viz.plot_confusion_matrix``.
    """
    events, _ = mne.events_from_annotations(raw, event_id=event_id)

    # Load the model object. joblib.load unpickles the file, so only
    # load model files from a trusted location.
    models = joblib.load(model_path)
    # NOTE(review): the trailing positional ``True`` is an option of
    # data_evaluation whose meaning is not visible here -- confirm.
    prob, y_pred = bci_pipeline.data_evaluation(
        models, raw.get_data(), raw.info['sfreq'], events,
        trial_duration, True,
    )

    # Metrics: AUC, accuracy, macro F1. Ground-truth labels are the event
    # codes in the last column of ``events``.
    y = events[:, -1]
    auc = bci_utils.multiclass_auc_score(y, prob)
    accu = accuracy_score(y, y_pred)
    # ``pos_label`` is ignored by scikit-learn when average != 'binary'
    # (passing it only raises a warning), so it is not passed here.
    f1 = f1_score(y, y_pred, average='macro')

    fig_conf = bci_viz.plot_confusion_matrix(y, y_pred)
    return (auc, accu, f1), fig_conf


def main():
    """Validate one subject's model and save/show the confusion matrix.

    Reads ``val_info.yml`` from ``<settings.DATA_PATH>/<subj>``: the
    ``sessions`` key is required and ``ori_epoch_length`` defaults to 5.0.
    The confusion-matrix figure is saved as ``confusion_matrix.pdf`` in the
    same directory and then shown.
    """
    args = parse_args()
    subj_name = args.subj

    data_dir = os.path.join(settings.DATA_PATH, subj_name)
    model_path = os.path.join(settings.MODEL_PATH, subj_name,
                              args.model_filename)

    with open(os.path.join(data_dir, 'val_info.yml'), 'r',
              encoding='utf-8') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
    ori_epoch_length = info.get('ori_epoch_length', 5.)

    # The configured buffer length is used both as the loader's upsampled
    # epoch length and as the evaluation trial duration (presumably to
    # match the online inference window -- confirm).
    upsampled_trial_duration = config_info['buffer_length']

    # Load and preprocess the validation recording.
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   ori_epoch_length=ori_epoch_length,
                                   reref_method=config_info['reref'],
                                   upsampled_epoch_length=upsampled_trial_duration)

    (auc, accu, f1), fig_conf = val_by_epochs(raw, model_path, event_id,
                                              upsampled_trial_duration)

    logger.info(f'Validation metrices: AUC: {auc:.4f}, Accuracy: {accu:.4f}, f1-score: {f1:.4f}')
    fig_conf.savefig(os.path.join(data_dir, 'confusion_matrix.pdf'))
    plt.show()


if __name__ == '__main__':
    main()