"""Model validation script: plot ERDS and time-frequency maps for one subject.

Loads the sessions listed in ``<DATA_PATH>/<subj>/val_info.yml``. Builds an
ERDS map and a time-frequency plot from the raw recording. Saves them as
``erds.pdf`` and ``tfr.pdf`` in the subject's data directory, then shows them.

NOTE(review): an earlier header said this script also tests AUC and draws a
confusion matrix. The code below does neither; the classifier-related imports
are currently unused.
"""
import os
import argparse
import logging

import mne
import yaml
import joblib
import numpy as np
from scipy import signal
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt

from dataloaders import neo
import bci_core.utils as bci_utils
import bci_core.pipeline as bci_pipeline
import bci_core.viz as bci_viz
from settings.config import settings


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO


def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace: ``subj`` (str) is the subject name. It is joined
        onto ``settings.DATA_PATH`` to locate the subject's data directory.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    # Required: the data directory is built from this value. With the old
    # ``default=None``, omitting it crashed later inside os.path.join with an
    # opaque TypeError; argparse now reports a clear usage error instead.
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    subj_name = args.subj

    data_dir = os.path.join(settings.DATA_PATH, subj_name)

    # val_info.yml provides the session list and, optionally, the original
    # epoch length (5.0 is used when the key is absent). The encoding is
    # explicit so the file reads the same regardless of the platform locale.
    with open(os.path.join(data_dir, 'val_info.yml'), 'r', encoding='utf-8') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
    ori_epoch_length = info.get('ori_epoch_length', 5.)

    # Load and preprocess the raw recording; the re-referencing method comes
    # from the project configuration.
    raw, event_id = neo.raw_loader(
        data_dir,
        sessions,
        ori_epoch_length=ori_epoch_length,
        reref_method=config_info['reref'],
        upsampled_epoch_length=None,
    )

    fs = raw.info['sfreq']
    events, _ = mne.events_from_annotations(raw, event_id)

    # ERDS map. NOTE(review): (0, 2.5) is presumably a time window in seconds,
    # and the trailing 0 is an option of bci_viz.plot_ersd; confirm both
    # against that helper.
    fig_erds = bci_viz.plot_ersd(raw.get_data(), events, fs, (0, 2.5), event_id, 0)

    # Time-frequency plot over np.arange(5, 200, 20) (presumably Hz) and a
    # (-1, 4) window (presumably seconds), with an inverted event_id mapping
    # (event code -> label) passed for labelling.
    fig_tfr = bci_viz.plot_time_frequency(
        raw.get_data(),
        events,
        fs,
        np.arange(5, 200, 20),
        (-1, 4),
        {v: k for k, v in event_id.items()},
    )

    fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
    fig_tfr.savefig(os.path.join(data_dir, 'tfr.pdf'))
    plt.show()