1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
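'''
Model validation script.

Planned: AUC evaluation and confusion-matrix plotting (not implemented below).
Currently plots the ERDS (event-related desynchronization/synchronization)
map and a time-frequency representation for a subject's validation sessions,
saving them as erds.pdf and tfr.pdf in the subject's data directory.
'''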
- import os
- import argparse
- import logging
- import mne
- import yaml
- import joblib
- import numpy as np
- from scipy import signal
- from sklearn.metrics import accuracy_score, f1_score
- import matplotlib.pyplot as plt
- from dataloaders import neo
- import bci_core.utils as bci_utils
- import bci_core.pipeline as bci_pipeline
- import bci_core.viz as bci_viz
- from settings.config import settings
# Configure the root logger to emit INFO and above. This runs before the
# settings lookup below so that any logging triggered there uses this level.
logging.basicConfig(level=logging.INFO)
# Module-level logger for this script.
logger = logging.getLogger(__name__)
# Project-wide configuration; the entry point reads config_info['reref']
# (presumably a dict-like mapping — confirm against settings.config).
config_info = settings.CONFIG_INFO
def parse_args(argv=None):
    """Parse command-line arguments for the model validation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with a ``subj`` attribute holding the subject name.

    Raises:
        SystemExit: if ``--subj`` is missing or the arguments are invalid
            (standard argparse behavior, exit status 2).
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    # --subj is required: the script joins it onto settings.DATA_PATH, so a
    # missing value would otherwise surface later as an opaque TypeError
    # from os.path.join instead of a clear usage error here.
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    return parser.parse_args(argv)
if __name__ == '__main__':
    args = parse_args()
    subj_name = args.subj
    data_dir = os.path.join(settings.DATA_PATH, subj_name)

    # Validation metadata: the sessions to load and, optionally, the original
    # epoch length (falls back to 5.0 when the key is absent). An explicit
    # encoding avoids depending on the platform's locale default.
    with open(os.path.join(data_dir, 'val_info.yml'), 'r', encoding='utf-8') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
    ori_epoch_length = info.get('ori_epoch_length', 5.)

    # Load and preprocess the raw recording, re-referenced per the global config.
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   ori_epoch_length=ori_epoch_length,
                                   reref_method=config_info['reref'],
                                   upsampled_epoch_length=None)
    fs = raw.info['sfreq']
    events, _ = mne.events_from_annotations(raw, event_id)

    # Materialize the signal array once; both plots consume the same data and
    # each get_data() call builds a new array from the recording.
    # NOTE(review): assumes the plotting helpers do not modify `data` in place.
    data = raw.get_data()

    # ERDS map
    fig_erds = bci_viz.plot_ersd(data, events, fs, (0, 2.5), event_id, 0)

    # Time-frequency plot; it receives event_id inverted (event code -> label).
    code_to_label = {v: k for k, v in event_id.items()}
    fig_tfr = bci_viz.plot_time_frequency(data, events, fs, np.arange(5, 200, 20),
                                          (-1, 4), code_to_label)

    erds_path = os.path.join(data_dir, 'erds.pdf')
    tfr_path = os.path.join(data_dir, 'tfr.pdf')
    fig_erds.savefig(erds_path)
    fig_tfr.savefig(tfr_path)
    logger.info('Saved ERDS map to %s', erds_path)
    logger.info('Saved time-frequency plot to %s', tfr_path)
    plt.show()
|