# band_selection.py
  1. import os
  2. import argparse
  3. import matplotlib.pyplot as plt
  4. import yaml
  5. import numpy as np
  6. from bci_core.frequencybandselection_helpers import freq_selection_class_dis
  7. from dataloaders import neo
  8. from sklearn.model_selection import ShuffleSplit
  9. from settings.config import settings
# Project-wide configuration from settings; this script reads its 'reref' and
# 'buffer_length' entries when loading the raw recordings.
config_info = settings.CONFIG_INFO
  11. def parse_args():
  12. parser = argparse.ArgumentParser(
  13. description='Model validation'
  14. )
  15. parser.add_argument(
  16. '--subj',
  17. dest='subj',
  18. help='Subject name',
  19. default=None,
  20. type=str
  21. )
  22. parser.add_argument(
  23. '--band-min',
  24. dest='b_min',
  25. help='Band lower range',
  26. default=5.,
  27. type=float
  28. )
  29. parser.add_argument(
  30. '--band-max',
  31. dest='b_max',
  32. help='Band upper range',
  33. default=200.,
  34. type=float
  35. )
  36. parser.add_argument(
  37. '--band-div',
  38. help='Dividing point of low frequency bands and high frequency bands',
  39. default=50.,
  40. type=float
  41. )
  42. return parser.parse_args()
  43. args = parse_args()
  44. subj_name = args.subj
  45. data_dir = os.path.join(settings.DATA_PATH, subj_name)
  46. with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
  47. info = yaml.safe_load(f)
  48. sessions = info['sessions']
  49. trial_duration = config_info['buffer_length']
  50. ori_epoch_length = info.get('ori_epoch_length', 5.)
  51. # preprocess raw
  52. raw, event_id = neo.raw_loader(data_dir, sessions, reref_method=config_info['reref'], upsampled_epoch_length=trial_duration, ori_epoch_length=ori_epoch_length)
  53. ###############################################################################
  54. # Pipeline with a frequency band selection based on the class distinctiveness
  55. # ----------------------------------------------------------------------------
  56. #
  57. # Step1: Select frequency band maximizing class distinctiveness on
  58. # training set.
  59. #
  60. # Define parameters for frequency band selection
  61. freq_band = [args.b_min, args.b_max]
  62. alpha = 0.4
  63. subband_fmean = np.logspace(np.log10(args.b_min), np.log10(args.b_max), 30)
  64. f_step = (np.log10(args.b_max) - np.log10(args.b_min)) / 30
  65. subband_fmin = subband_fmean * (10 ** (-f_step))
  66. subband_fmax = subband_fmean * (10 ** f_step)
  67. # cross validation
  68. cv = ShuffleSplit(n_splits=1, test_size=None, random_state=42)
  69. # Select frequency band using training set
  70. best_freq, all_class_dis = \
  71. freq_selection_class_dis(raw, subband_fmin, subband_fmax, alpha, args.band_div,
  72. tmin=0., tmax=0.5,
  73. cv=cv,
  74. return_class_dis=True, verbose=False)
  75. print(f'Selected frequency band (low) : {best_freq[0][0][0]} - {best_freq[0][0][1]} Hz')
  76. print(f'Selected frequency band (high) : {best_freq[0][1][0]} - {best_freq[0][1][1]} Hz')
  77. # split high frequency band
  78. n_bands = np.round((best_freq[0][1][1] - best_freq[0][1][0]) / 50 + 1).astype(int)
  79. split_high_freqs = np.logspace(np.log10(best_freq[0][1][0]), np.log10(best_freq[0][1][1]), n_bands)
  80. print('Splited high frequency bands: ', split_high_freqs)
  81. ###############################################################################
  82. # Plot selected frequency bands
  83. # ----------------------------------
  84. #
  85. # Plot the class distinctiveness values for each sub_band,
  86. # along with the highlight of the finally selected frequency band.
  87. fig, ax = plt.subplots(1, 1, figsize=(8, 5))
  88. ax.plot(subband_fmean, all_class_dis[0], marker='o')
  89. ax.axvspan(best_freq[0][0][0], best_freq[0][0][1], color='orange', alpha=0.5)
  90. ax.axvspan(best_freq[0][1][0], best_freq[0][1][1], color='orange', alpha=0.5)
  91. ax.set_ylabel('Class distinctiveness')
  92. ax.set_xlabel('Filter bank [Hz]')
  93. ax.set_title('Class distinctiveness value of each subband')
  94. ax.tick_params(labelsize='large')
  95. fig.tight_layout()
  96. fig.savefig(os.path.join(data_dir, 'freq_selection.pdf'))
  97. plt.show()