1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
'''
Validation figures for one subject: ERDS map, per-class TFR, raw-signal TFR,
and high-gamma envelope plots.
'''
- import os
- import argparse
- import logging
- import mne
- import yaml
- import numpy as np
- import matplotlib.pyplot as plt
- from dataloaders import neo
- import bci_core.viz as bci_viz
- from settings.config import settings
# Root logging at INFO so progress messages from this script are visible.
logging.basicConfig(level='INFO')
logger = logging.getLogger(name=__name__)
# Project-wide configuration; the script reads its 're-reference' method from here.
config_info = settings.CONFIG_INFO
def parse_args(argv=None):
    """Parse the command-line options of the figure script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with attribute ``subj`` (str), the subject name.

    Raises:
        SystemExit: if ``--subj`` is missing or arguments are invalid
            (argparse prints a usage message first).
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str,
        # The subject is mandatory: the caller joins it onto the data path,
        # and a None there fails later with an unhelpful TypeError.
        required=True,
    )
    return parser.parse_args(argv)
def main():
    """Draw validation figures for one subject and save them to its data dir.

    Reads ``val_info.yml`` from ``<settings.DATA_PATH>/<subj>``, loads the
    recordings listed under its ``sessions`` key via ``neo.raw_loader``, then
    draws an ERDS map, a per-class TFR, a raw-signal TFR and, for every
    session key, two high-gamma envelope plots. The first three figures are
    written as PDFs into the subject directory; afterwards ``plt.show()`` is
    called to display the figures.
    """
    args = parse_args()
    subj_name = args.subj
    if not subj_name:
        # Without a subject, os.path.join would either raise an opaque
        # TypeError (None) or silently point at DATA_PATH itself ('').
        raise SystemExit('error: --subj is required')
    data_dir = os.path.join(settings.DATA_PATH, subj_name)

    with open(os.path.join(data_dir, 'val_info.yml'), 'r', encoding='utf-8') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
    ori_epoch_length = info.get('ori_epoch_length', 5.)

    # preprocess raw
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   ori_epoch_length=ori_epoch_length,
                                   reref_method=config_info['reref'],
                                   upsampled_epoch_length=None)
    fs = raw.info['sfreq']
    # event_id is passed to MNE as the annotation-description -> code mapping.
    events, _ = mne.events_from_annotations(raw, event_id)

    # NOTE(review): every plot receives its own raw.get_data() copy, exactly
    # as before. Sharing one array would avoid repeated copies but is only
    # safe if the bci_viz helpers never modify their input in place — confirm.

    # ersd map
    fig_erds = bci_viz.plot_ersd(raw.get_data(), events, fs, (0, 2.5), event_id, 0)
    # tfr plot, labelled with the inverse mapping (event code -> event name)
    code_to_name = {v: k for k, v in event_id.items()}
    fig_tfr = bci_viz.plot_cls_tfr(raw.get_data(), events, fs, np.arange(5, 200, 20),
                                   (-1, 4), code_to_name)
    # plot raw tfr
    fig_tfr_raw = bci_viz.plot_raw_tfr(raw.get_data(), fs, np.arange(5, 200, 10), n_cycles=20)

    # hg average line plot: two bands (55-95 and 95-155) per session key
    # NOTE(review): fig_hgs is collected but never saved, unlike the other
    # figures; they are only visible through plt.show(). Decide whether they
    # should also be written to data_dir.
    fig_hgs = {}
    for t in sessions:
        fig_hg_1 = bci_viz.plot_hg_envelope(raw, events, event_id, fs, (55, 95), -1, 5,
                                            target_event=t)
        fig_hg_2 = bci_viz.plot_hg_envelope(raw, events, event_id, fs, (95, 155), -1, 5,
                                            t_smooth=0.6, target_event=t)
        fig_hgs[t] = (fig_hg_1, fig_hg_2)

    logger.info('Saving figures to %s', data_dir)
    fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
    fig_tfr.savefig(os.path.join(data_dir, 'tfr.pdf'))
    fig_tfr_raw.savefig(os.path.join(data_dir, 'tfr_raw.pdf'))
    plt.show()


if __name__ == '__main__':
    main()
|