123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
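'''
Model validation script:
- compute AUC, accuracy and macro F1 of a saved model on the validation sessions,
- plot the confusion matrix and the ERDS (event-related
  desynchronization/synchronization) map.
'''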
- import os
- import argparse
- import logging
- import mne
- import yaml
- import joblib
- import numpy as np
- from scipy import signal
- from sklearn.metrics import accuracy_score, f1_score
- import matplotlib.pyplot as plt
- from dataloaders import neo
- import bci_core.utils as bci_utils
- import bci_core.viz as bci_viz
- from settings.config import settings
# Configure the root logger so INFO-level messages from this script are shown.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger; the main block uses it to report validation metrics.
logger = logging.getLogger(name=__name__)
# Project runtime configuration; its 'buffer_length' entry sets the validation
# epoch length used in the main block.
config_info = getattr(settings, 'CONFIG_INFO')
def parse_args(argv=None):
    """Parse the command-line arguments of the validation script.

    Both options are required. The script builds its data directory
    (``./data/<subj>/``) and model path (``./static/models/<subj>/<file>``)
    from them. With a ``None`` default, a missing option would silently turn
    into the literal string ``'None'`` inside those paths.

    Args:
        argv: optional list of argument strings; ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with ``subj`` and ``model_filename`` attributes.

    Raises:
        SystemExit: if an option is missing or malformed (argparse behavior).
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        required=True,
        type=str,
    )
    return parser.parse_args(argv)
def val_by_epochs(raw, model_path, event_id, trial_duration=1.):
    """Validate a saved model on epochs cut at the annotated events of ``raw``.

    Args:
        raw: recording with annotations; events are extracted from it with
            ``mne.events_from_annotations`` (presumably an ``mne.io.Raw``).
        model_path: path of the joblib-dumped model. Its type is obtained
            from ``bci_utils.parse_model_type`` and must be 'baseline' or
            'riemann'.
        event_id: mapping from annotation description to integer event code,
            forwarded to ``mne.events_from_annotations``.
        trial_duration: epoch length starting at each event onset, forwarded
            to the model-specific helper (presumably seconds - confirm).

    Returns:
        ``((auc, accuracy, macro_f1), confusion_matrix_figure)``.

    Raises:
        ValueError: if the model type is not 'baseline' or 'riemann'.
    """
    events, _ = mne.events_from_annotations(raw, event_id=event_id)

    # Pick the prediction routine matching the saved model's type.
    model_type, _ = bci_utils.parse_model_type(model_path)
    if model_type == 'baseline':
        predict = _val_by_epochs_baseline
    elif model_type == 'riemann':
        predict = _val_by_epochs_riemann
    else:
        raise ValueError('Unaccepted model type')
    prob, y_pred = predict(raw, events, model_path, trial_duration)

    # Metrics: the ground-truth label of each epoch is its event code
    # (last column of the events array).
    y = events[:, -1]
    auc = bci_utils.multiclass_auc_score(y, prob)
    accu = accuracy_score(y, y_pred)
    # Macro-averaged F1 over all classes. `pos_label` is only meaningful for
    # average='binary'. It was previously passed here, ignored by scikit-learn,
    # and triggered a UserWarning, so it is omitted.
    f1 = f1_score(y, y_pred, average='macro')

    fig_conf = bci_viz.plot_confusion_matrix(y, y_pred)
    return (auc, accu, f1), fig_conf
def _val_by_epochs_baseline(raw, events, model_path, duration):
    """Predict class probabilities for every event with a 'baseline' model.

    The joblib file is expected to hold ``(feature_extractor, classifier)``.
    The processing steps are:

    - extract filter-bank features from the whole recording,
    - cut them into epochs of ``duration`` starting at each event onset,
    - downsample the epochs in two equal decimation stages,
    - classify the result.

    Args:
        raw: recording exposing ``info['sfreq']`` and ``get_data()``.
        events: events array whose column 0 holds the onset sample indices.
        model_path: path of the joblib-dumped ``(extractor, classifier)``.
        duration: epoch length forwarded to ``bci_utils.cut_epochs``.

    Returns:
        ``(prob, y_pred)``: the classifier's class probabilities and, per
        epoch, the ``classes_`` entry with the highest probability.
    """
    sfreq = raw.info['sfreq']
    extractor, clf = joblib.load(model_path)

    features = extractor.transform(raw.get_data())
    epochs = bci_utils.cut_epochs((0, duration, sfreq), features, events[:, 0])

    # Aim for roughly 10 Hz with two identical decimation stages of
    # floor(sqrt(fs / 10)) each, e.g. 1000 Hz -> 100 Hz -> 10 Hz.
    # The result is exactly 10 Hz only when fs / 10 is a perfect square.
    # Presumably the factor is split to keep each IIR decimation step small,
    # as scipy advises for large factors.
    stage_factor = np.sqrt(sfreq / 10).astype(np.int16)
    for _ in range(2):
        epochs = signal.decimate(epochs, stage_factor, axis=-1, zero_phase=True)

    prob = clf.predict_proba(epochs)
    best = np.argmax(prob, axis=1)
    return prob, clf.classes_[best]
def _val_by_epochs_riemann(raw, events, model_path, duration):
    """Predict class probabilities for every event with a 'riemann' model.

    The joblib file is expected to hold
    ``(feature_extractor, scaler, cov_model, classifier)``.
    The processing steps are:

    - filter the whole recording,
    - cut it into epochs of ``duration`` at each event onset,
    - scale the epochs,
    - transform them with ``cov_model`` (presumably into covariance matrices),
    - classify the result.

    Args:
        raw: recording exposing ``info['sfreq']`` and ``get_data()``.
        events: events array whose column 0 holds the onset sample indices.
        model_path: path of the joblib-dumped 4-tuple described above.
        duration: epoch length forwarded to ``bci_utils.cut_epochs``.

    Returns:
        ``(prob, y_pred)``: the classifier's class probabilities and, per
        epoch, the ``classes_`` entry with the highest probability.
    """
    sfreq = raw.info['sfreq']
    extractor, scaler, cov_estimator, clf = joblib.load(model_path)

    filtered = extractor.transform(raw.get_data())
    epochs = scaler.transform(
        bci_utils.cut_epochs((0, duration, sfreq), filtered, events[:, 0])
    )
    covariances = cov_estimator.transform(epochs)

    prob = clf.predict_proba(covariances)
    best = np.argmax(prob, axis=1)
    return prob, clf.classes_[best]
if __name__ == '__main__':
    # Script entry point. It performs these steps:
    # - load the subject's validation sessions,
    # - plot the ERDS map,
    # - validate the chosen model and log its metrics,
    # - save both figures to the subject's data directory.
    args = parse_args()
    subj_name = args.subj
    data_dir = f'./data/{subj_name}/'

    model_path = f'./static/models/{subj_name}/{args.model_filename}'

    # The validation sessions are listed in val_info.yml. Read the file as
    # UTF-8 explicitly so parsing does not depend on the platform's default
    # locale encoding.
    val_info_path = os.path.join(data_dir, 'val_info.yml')
    with open(val_info_path, 'r', encoding='utf-8') as f:
        info = yaml.safe_load(f)
    # An empty YAML file loads as None. Fail with a clear message instead of
    # an opaque TypeError/KeyError.
    if not isinstance(info, dict) or 'sessions' not in info:
        raise ValueError(f"{val_info_path} must define a 'sessions' entry")
    sessions = info['sessions']

    # Load and preprocess the raw recordings. trial_time is forwarded as the
    # original epoch length and the configured buffer length as the upsampled
    # epoch length (presumably seconds - confirm in neo.raw_loader).
    trial_time = 5.
    upsampled_trial_duration = config_info['buffer_length']
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   ori_epoch_length=trial_time,
                                   upsampled_epoch_length=upsampled_trial_duration)

    fs = raw.info['sfreq']

    events, _ = mne.events_from_annotations(raw, event_id)
    # ERDS map over the window (0, upsampled_trial_duration) after each event.
    fig_erds = bci_viz.plot_ersd(raw.get_data(), events, fs,
                                 (0, upsampled_trial_duration), event_id, 0)
    # Validate the model on the same epoch length and log the metrics.
    metrices, fig_conf = val_by_epochs(raw, model_path, event_id,
                                       upsampled_trial_duration)
    logger.info(f'Validation metrices: AUC: {metrices[0]:.4f}, Accuracy: {metrices[1]:.4f}, f1-score: {metrices[2]:.4f}')

    fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
    fig_conf.savefig(os.path.join(data_dir, 'confusion_matrix.pdf'))
    plt.show()
|