'''
signal_visualization.py

Figures: ERSD map, TFR (raw and cls)
'''
  4. import os
  5. import argparse
  6. import logging
  7. import mne
  8. import yaml
  9. import numpy as np
  10. import matplotlib.pyplot as plt
  11. from dataloaders import neo
  12. import bci_core.viz as bci_viz
  13. from settings.config import settings
  14. logging.basicConfig(level=logging.INFO)
  15. logger = logging.getLogger(__name__)
  16. config_info = settings.CONFIG_INFO
  17. def parse_args():
  18. parser = argparse.ArgumentParser(
  19. description='Model validation'
  20. )
  21. parser.add_argument(
  22. '--subj',
  23. dest='subj',
  24. help='Subject name',
  25. default=None,
  26. type=str
  27. )
  28. return parser.parse_args()
  29. if __name__ == '__main__':
  30. args = parse_args()
  31. subj_name = args.subj
  32. data_dir = os.path.join(settings.DATA_PATH, subj_name)
  33. with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
  34. info = yaml.safe_load(f)
  35. sessions = info['sessions']
  36. ori_epoch_length = info.get('ori_epoch_length', 5.)
  37. # preprocess raw
  38. raw, event_id = neo.raw_loader(data_dir, sessions,
  39. ori_epoch_length=ori_epoch_length,
  40. reref_method=config_info['reref'],
  41. upsampled_epoch_length=None)
  42. fs = raw.info['sfreq']
  43. events, _ = mne.events_from_annotations(raw, event_id)
  44. # ersd map
  45. fig_erds = bci_viz.plot_ersd(raw.get_data(), events, fs, (0, 2.5), event_id, 0)
  46. # tfr plot
  47. fig_tfr = bci_viz.plot_cls_tfr(raw.get_data(), events, fs, np.arange(5, 200, 20), (-1, 4), {v: k for k, v in event_id.items()})
  48. # plot raw tfr
  49. fig_tfr_raw = bci_viz.plot_raw_tfr(raw.get_data(), fs, np.arange(5, 200, 10), n_cycles=20)
  50. # hg average line plot
  51. fig_hgs = {}
  52. for t in sessions.keys():
  53. fig_hg_1 = bci_viz.plot_hg_envelope(raw, events, event_id, fs, (55, 95), -1, 5, target_event=t)
  54. fig_hg_2 = bci_viz.plot_hg_envelope(raw, events, event_id, fs, (95, 155), -1, 5, t_smooth=0.6, target_event=t)
  55. fig_hgs[t] = (fig_hg_1, fig_hg_2)
  56. fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
  57. fig_tfr.savefig(os.path.join(data_dir, 'tfr.pdf'))
  58. fig_tfr_raw.savefig(os.path.join(data_dir, 'tfr_raw.pdf'))
  59. plt.show()