signal_visualization.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
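'''
Signal visualization script for one subject.

Loads the sessions listed in the subject's val_info.yml, then plots the
ERD/ERS (ERDS) map and the time-frequency representation (TFR) of the
event-locked data, and saves them as erds.pdf and tfr.pdf in the subject's
data directory.
'''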
  6. import os
  7. import argparse
  8. import logging
  9. import mne
  10. import yaml
  11. import joblib
  12. import numpy as np
  13. from scipy import signal
  14. from sklearn.metrics import accuracy_score, f1_score
  15. import matplotlib.pyplot as plt
  16. from dataloaders import neo
  17. import bci_core.utils as bci_utils
  18. import bci_core.pipeline as bci_pipeline
  19. import bci_core.viz as bci_viz
  20. from settings.config import settings
  21. logging.basicConfig(level=logging.INFO)
  22. logger = logging.getLogger(__name__)
  23. config_info = settings.CONFIG_INFO
  24. def parse_args():
  25. parser = argparse.ArgumentParser(
  26. description='Model validation'
  27. )
  28. parser.add_argument(
  29. '--subj',
  30. dest='subj',
  31. help='Subject name',
  32. default=None,
  33. type=str
  34. )
  35. return parser.parse_args()
  36. if __name__ == '__main__':
  37. args = parse_args()
  38. subj_name = args.subj
  39. data_dir = os.path.join(settings.DATA_PATH, subj_name)
  40. with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
  41. info = yaml.safe_load(f)
  42. sessions = info['sessions']
  43. ori_epoch_length = info.get('ori_epoch_length', 5.)
  44. # preprocess raw
  45. raw, event_id = neo.raw_loader(data_dir, sessions,
  46. ori_epoch_length=ori_epoch_length,
  47. reref_method=config_info['reref'],
  48. upsampled_epoch_length=None)
  49. fs = raw.info['sfreq']
  50. events, _ = mne.events_from_annotations(raw, event_id)
  51. # ersd map
  52. fig_erds = bci_viz.plot_ersd(raw.get_data(), events, fs, (0, 2.5), event_id, 0)
  53. # tfr plot
  54. fig_tfr = bci_viz.plot_time_frequency(raw.get_data(), events, fs, np.arange(5, 200, 20), (-1, 4), {v: k for k, v in event_id.items()})
  55. fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
  56. fig_tfr.savefig(os.path.join(data_dir, 'tfr.pdf'))
  57. plt.show()