''' Figures, ERSD map, tfr raw and cls '''
import os
import argparse
import logging

import mne
import yaml
import numpy as np
import matplotlib.pyplot as plt

from dataloaders import neo
import bci_core.viz as bci_viz
from settings.config import settings


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO


def parse_args():
    """Parse the command-line arguments of this figure script.

    Returns:
        argparse.Namespace with attribute ``subj`` (str): the subject
        directory name, joined onto ``settings.DATA_PATH``.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    # Required: the subject name is joined into a path immediately, so a
    # missing value (previously default=None) crashed os.path.join with a
    # TypeError instead of printing a usage error.
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    return parser.parse_args()


def main():
    """Load a subject's validation sessions, draw the figures and save PDFs.

    Reads ``<DATA_PATH>/<subj>/val_info.yml`` (keys ``sessions`` and the
    optional ``ori_epoch_length``, default 5.0). It then loads the raw
    recording through ``neo.raw_loader`` and draws the following figures:

    * the ERSD map,
    * the per-class TFR,
    * the raw TFR,
    * two high-gamma envelope plots for every key of ``sessions``.

    The first three figures are saved as ``erds.pdf``, ``tfr.pdf`` and
    ``tfr_raw.pdf`` in the subject directory. Finally, all figures are
    displayed with ``plt.show()``.
    """
    args = parse_args()
    subj_name = args.subj
    data_dir = os.path.join(settings.DATA_PATH, subj_name)

    with open(os.path.join(data_dir, 'val_info.yml'), 'r', encoding='utf-8') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
    ori_epoch_length = info.get('ori_epoch_length', 5.)

    # preprocess raw
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   ori_epoch_length=ori_epoch_length,
                                   reref_method=config_info['reref'],
                                   upsampled_epoch_length=None)
    fs = raw.info['sfreq']
    events, _ = mne.events_from_annotations(raw, event_id)

    # NOTE(review): raw.get_data() returns a new array on each call; it is
    # deliberately called per plot in case a viz helper modifies its input.

    # ersd map
    fig_erds = bci_viz.plot_ersd(raw.get_data(), events, fs, (0, 2.5), event_id, 0)

    # tfr per class; the label mapping is event_id inverted
    # (event code -> annotation description).
    # Frequencies are presumably in Hz - confirm against bci_core.viz.
    fig_tfr = bci_viz.plot_cls_tfr(raw.get_data(), events, fs,
                                   np.arange(5, 200, 20), (-1, 4),
                                   {v: k for k, v in event_id.items()})

    # plot raw tfr
    fig_tfr_raw = bci_viz.plot_raw_tfr(raw.get_data(), fs,
                                       np.arange(5, 200, 10), n_cycles=20)

    # hg average line plot, two bands per session key
    fig_hgs = {}
    for t in sessions:
        fig_hg_1 = bci_viz.plot_hg_envelope(raw, events, event_id, fs, (55, 95),
                                            -1, 5, target_event=t)
        fig_hg_2 = bci_viz.plot_hg_envelope(raw, events, event_id, fs, (95, 155),
                                            -1, 5, t_smooth=0.6, target_event=t)
        fig_hgs[t] = (fig_hg_1, fig_hg_2)
    # NOTE(review): fig_hgs is never written to disk; these figures are only
    # displayed by plt.show(). Confirm whether they should also be saved.

    for fig, fname in ((fig_erds, 'erds.pdf'),
                       (fig_tfr, 'tfr.pdf'),
                       (fig_tfr_raw, 'tfr_raw.pdf')):
        out_path = os.path.join(data_dir, fname)
        fig.savefig(out_path)
        logger.info('Saved %s', out_path)

    plt.show()


if __name__ == '__main__':
    main()