band_selection.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import os
  2. import argparse
  3. import matplotlib.pyplot as plt
  4. import yaml
  5. import numpy as np
  6. from bci_core.frequencybandselection_helpers import freq_selection_class_dis
  7. from dataloaders import neo
  8. from sklearn.model_selection import ShuffleSplit
  9. from settings.config import settings
# Project-level configuration; this script only reads its 'buffer_length'
# entry, which is passed to neo.raw_preprocessing as upsampled_epoch_length.
config_info = settings.CONFIG_INFO
  11. def parse_args():
  12. parser = argparse.ArgumentParser(
  13. description='Model validation'
  14. )
  15. parser.add_argument(
  16. '--subj',
  17. dest='subj',
  18. help='Subject name',
  19. default=None,
  20. type=str
  21. )
  22. parser.add_argument(
  23. '--band-min',
  24. dest='b_min',
  25. help='Band lower range',
  26. default=5.,
  27. type=float
  28. )
  29. parser.add_argument(
  30. '--band-max',
  31. dest='b_max',
  32. help='Band upper range',
  33. default=45.,
  34. type=float
  35. )
  36. return parser.parse_args()
  37. args = parse_args()
  38. subj_name = args.subj
  39. data_dir = f'./data/{subj_name}/'
  40. with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
  41. info = yaml.safe_load(f)
  42. sessions = info['sessions']
  43. event_id = {'rest': 0}
  44. for f in sessions.keys():
  45. event_id[f] = neo.FINGERMODEL_IDS[f]
  46. trial_duration = config_info['buffer_length']
  47. # preprocess raw
  48. raw = neo.raw_preprocessing(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
  49. ###############################################################################
  50. # Pipeline with a frequency band selection based on the class distinctiveness
  51. # ----------------------------------------------------------------------------
  52. #
  53. # Step1: Select frequency band maximizing class distinctiveness on
  54. # training set.
  55. #
  56. # Define parameters for frequency band selection
  57. freq_band = [args.b_min, args.b_max]
  58. sub_band_width = 5.
  59. sub_band_step = 5.
  60. alpha = 0.4
  61. # cross validation
  62. cv = ShuffleSplit(n_splits=1, test_size=None, random_state=42)
  63. # Select frequency band using training set
  64. best_freq, all_class_dis = \
  65. freq_selection_class_dis(raw, freq_band, sub_band_width,
  66. sub_band_step, alpha,
  67. tmin=0., tmax=0.5,
  68. cv=cv,
  69. return_class_dis=True, verbose=False)
  70. print(f'Selected frequency band : {best_freq[0][0]} - {best_freq[0][1]} Hz')
  71. ###############################################################################
  72. # Plot selected frequency bands
  73. # ----------------------------------
  74. #
  75. # Plot the class distinctiveness values for each sub_band,
  76. # along with the highlight of the finally selected frequency band.
  77. subband_fmin = np.arange(freq_band[0],
  78. freq_band[1] - sub_band_width + 1.,
  79. sub_band_step)
  80. subband_fmax = np.arange(freq_band[0] + sub_band_width,
  81. freq_band[1] + 1., sub_band_step)
  82. n_subband = len(subband_fmin)
  83. subband_fmean = (subband_fmin + subband_fmax) / 2
  84. x = subband_fmean
  85. fig, ax = plt.subplots(1, 1, figsize=(8, 5))
  86. ax.plot(x, all_class_dis[0], marker='o')
  87. ax.set_ylabel('Class distinctiveness')
  88. ax.set_xlabel('Filter bank [Hz]')
  89. ax.set_title('Class distinctiveness value of each subband')
  90. ax.tick_params(labelsize='large')
  91. fig.tight_layout()
  92. fig.savefig(os.path.join(data_dir, 'freq_selection.pdf'))
  93. print(f'Optimal frequency band for this subject is '
  94. f'{best_freq[0][0]} - {best_freq[0][1]} Hz')
  95. plt.show()