'''Model validation script.

Computes AUC, accuracy and F1-score of a trained model on the validation
sessions of one subject, and plots the confusion matrix and the ERSD map.
'''
import os
import argparse
import logging

import mne
import yaml
import joblib
import numpy as np
from scipy import signal
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt

from dataloaders import neo
import bci_core.utils as bci_utils
import bci_core.viz as bci_viz
from settings.config import settings


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO


def parse_args():
    '''Parse the command-line options of the validation script.

    Both options are required: they are interpolated into the data and model
    paths, so a missing value would otherwise silently become the literal
    string "None" inside those paths.

    Returns:
        argparse.Namespace with attributes ``subj`` and ``model_filename``.
    '''
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        required=True,
        type=str
    )
    return parser.parse_args()


def val_by_epochs(raw, model_path, event_id, trial_duration=1., ):
    '''Validate a saved model epoch-by-epoch on a raw recording.

    Args:
        raw: recording that carries annotations; events are extracted from it
            with ``mne.events_from_annotations``.
        model_path: path of the saved model; its type is resolved with
            ``bci_utils.parse_model_type``.
        event_id: annotation-to-event-code mapping forwarded to MNE.
        trial_duration: epoch length handed to the epoch cutter
            (presumably seconds -- confirm against ``bci_utils.cut_epochs``).

    Returns:
        ``((auc, accuracy, f1), fig_conf)`` where ``fig_conf`` is the figure
        returned by ``bci_viz.plot_confusion_matrix``.

    Raises:
        ValueError: if the model type is neither 'baseline' nor 'riemann'.
    '''
    events, _ = mne.events_from_annotations(raw, event_id=event_id)

    # parse model type
    model_type, _ = bci_utils.parse_model_type(model_path)
    if model_type == 'baseline':
        prob, y_pred = _val_by_epochs_baseline(raw, events, model_path, trial_duration)
    elif model_type == 'riemann':
        prob, y_pred = _val_by_epochs_riemann(raw, events, model_path, trial_duration)
    else:
        raise ValueError('Unaccepted model type')

    # metrics: AUC, accuracy, F1
    # last column of the MNE events array holds the event code (the label)
    y = events[:, -1]
    auc = bci_utils.multiclass_auc_score(y, prob)
    accu = accuracy_score(y, y_pred)
    # pos_label is ignored by scikit-learn when average != 'binary' (and
    # triggers a UserWarning when it is not 1), so it is not passed here.
    f1 = f1_score(y, y_pred, average='macro')

    # confusion matrix
    fig_conf = bci_viz.plot_confusion_matrix(y, y_pred)

    return (auc, accu, f1), fig_conf


def _val_by_epochs_baseline(raw, events, model_path, duration):
    '''Predict every epoch with a 'baseline' model.

    The model file is expected to unpickle to a
    ``(feat_extractor, baseline_model)`` pair.

    Returns:
        ``(prob, y_pred)``: the output of ``predict_proba`` and, per epoch,
        the class with the highest probability.
    '''
    fs = raw.info['sfreq']

    # NOTE: joblib.load unpickles arbitrary objects -- only load trusted files.
    feat_extractor, baseline_model = joblib.load(model_path)

    filter_bank_data = feat_extractor.transform(raw.get_data())

    # events[:, 0] is the sample index of each event onset
    filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])

    # Downsample to ~10 Hz in two equal stages (e.g. 1000 -> 100 -> 10 Hz);
    # scipy recommends several decimate calls for large overall factors.
    # The per-stage factor is truncated to an integer, so the final rate is
    # exactly 10 Hz only when fs / 10 is a perfect square.
    decimate_rate = np.sqrt(fs / 10).astype(np.int16)
    # first stage
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    # second stage
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    X = filter_bank_epoch

    # pred
    prob = baseline_model.predict_proba(X)
    y_pred = baseline_model.classes_[np.argmax(prob, axis=1)]
    return prob, y_pred


def _val_by_epochs_riemann(raw, events, model_path, duration):
    '''Predict every epoch with a 'riemann' model.

    The model file is expected to unpickle to a
    ``(feat_extractor, scaler, cov_model, riemann_model)`` tuple, applied in
    that order.

    Returns:
        ``(prob, y_pred)``: the output of ``predict_proba`` and, per epoch,
        the class with the highest probability.
    '''
    fs = raw.info['sfreq']

    # NOTE: joblib.load unpickles arbitrary objects -- only load trusted files.
    feat_extractor, scaler, cov_model, riemann_model = joblib.load(model_path)

    filtered_data = feat_extractor.transform(raw.get_data())
    # events[:, 0] is the sample index of each event onset
    X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
    X = scaler.transform(X)
    X_cov = cov_model.transform(X)

    # pred
    prob = riemann_model.predict_proba(X_cov)
    y_pred = riemann_model.classes_[np.argmax(prob, axis=1)]
    return prob, y_pred


def main():
    '''Run the validation for the subject/model given on the command line.

    Loads the validation sessions listed in ``val_info.yml``, plots the ERSD
    map, evaluates the model, logs the metrics and saves both figures as PDF
    files in the subject's data directory before showing them.
    '''
    args = parse_args()

    subj_name = args.subj

    data_dir = f'./data/{subj_name}/'

    model_path = f'./static/models/{subj_name}/{args.model_filename}'

    with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']

    # preprocess raw
    trial_time = 5.
    upsampled_trial_duration = config_info['buffer_length']
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   ori_epoch_length=trial_time,
                                   upsampled_epoch_length=upsampled_trial_duration)

    fs = raw.info['sfreq']
    events, _ = mne.events_from_annotations(raw, event_id)

    # ersd map
    fig_erds = bci_viz.plot_ersd(raw.get_data(), events, fs, (0, upsampled_trial_duration), event_id, 0)

    # Do validations
    metrices, fig_conf = val_by_epochs(raw, model_path, event_id, upsampled_trial_duration)

    # log results
    logger.info(f'Validation metrices: AUC: {metrices[0]:.4f}, Accuracy: {metrices[1]:.4f}, f1-score: {metrices[2]:.4f}')

    fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
    fig_conf.savefig(os.path.join(data_dir, 'confusion_matrix.pdf'))
    plt.show()


if __name__ == '__main__':
    main()