band_selection.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import os
  2. import argparse
  3. import matplotlib.pyplot as plt
  4. import yaml
  5. import numpy as np
  6. from bci_core.frequencybandselection_helpers import freq_selection_class_dis
  7. from dataloaders import neo
  8. from sklearn.model_selection import ShuffleSplit
  9. from settings.config import settings
# Project-wide configuration object exposed by settings.config (dict-like;
# this script looks up its 'buffer_length' key to get the trial duration).
config_info = settings.CONFIG_INFO
  11. def parse_args():
  12. parser = argparse.ArgumentParser(
  13. description='Model validation'
  14. )
  15. parser.add_argument(
  16. '--subj',
  17. dest='subj',
  18. help='Subject name',
  19. default=None,
  20. type=str
  21. )
  22. parser.add_argument(
  23. '--band-min',
  24. dest='b_min',
  25. help='Band lower range',
  26. default=5.,
  27. type=float
  28. )
  29. parser.add_argument(
  30. '--band-max',
  31. dest='b_max',
  32. help='Band upper range',
  33. default=45.,
  34. type=float
  35. )
  36. return parser.parse_args()
  37. args = parse_args()
  38. subj_name = args.subj
  39. data_dir = f'./data/{subj_name}/'
  40. with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
  41. info = yaml.safe_load(f)
  42. sessions = info['sessions']
  43. trial_duration = config_info['buffer_length']
  44. # preprocess raw
  45. raw, event_id = neo.raw_loader(data_dir, sessions, do_rereference=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5)
  46. ###############################################################################
  47. # Pipeline with a frequency band selection based on the class distinctiveness
  48. # ----------------------------------------------------------------------------
  49. #
  50. # Step1: Select frequency band maximizing class distinctiveness on
  51. # training set.
  52. #
  53. # Define parameters for frequency band selection
  54. freq_band = [args.b_min, args.b_max]
  55. sub_band_width = 5.
  56. sub_band_step = 5.
  57. alpha = 0.4
  58. # cross validation
  59. cv = ShuffleSplit(n_splits=1, test_size=None, random_state=42)
  60. # Select frequency band using training set
  61. best_freq, all_class_dis = \
  62. freq_selection_class_dis(raw, freq_band, sub_band_width,
  63. sub_band_step, alpha,
  64. tmin=0., tmax=0.5,
  65. cv=cv,
  66. return_class_dis=True, verbose=False)
  67. print(f'Selected frequency band : {best_freq[0][0]} - {best_freq[0][1]} Hz')
  68. ###############################################################################
  69. # Plot selected frequency bands
  70. # ----------------------------------
  71. #
  72. # Plot the class distinctiveness values for each sub_band,
  73. # along with the highlight of the finally selected frequency band.
  74. subband_fmin = np.arange(freq_band[0],
  75. freq_band[1] - sub_band_width + 1.,
  76. sub_band_step)
  77. subband_fmax = np.arange(freq_band[0] + sub_band_width,
  78. freq_band[1] + 1., sub_band_step)
  79. n_subband = len(subband_fmin)
  80. subband_fmean = (subband_fmin + subband_fmax) / 2
  81. x = subband_fmean
  82. fig, ax = plt.subplots(1, 1, figsize=(8, 5))
  83. ax.plot(x, all_class_dis[0], marker='o')
  84. ax.set_ylabel('Class distinctiveness')
  85. ax.set_xlabel('Filter bank [Hz]')
  86. ax.set_title('Class distinctiveness value of each subband')
  87. ax.tick_params(labelsize='large')
  88. fig.tight_layout()
  89. fig.savefig(os.path.join(data_dir, 'freq_selection.pdf'))
  90. print(f'Optimal frequency band for this subject is '
  91. f'{best_freq[0][0]} - {best_freq[0][1]} Hz')
  92. plt.show()