frequencybandselection_helpers.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. """
  2. =================================
  3. Frequency Band Selection Helpers
  4. =================================
  5. This file contains helper functions for the frequency band selection example
  6. """
  7. import numpy as np
  8. from mne import Epochs, events_from_annotations
  9. from pyriemann.estimation import Covariances, Shrinkage
  10. from pyriemann.classification import class_distinctiveness
  11. def freq_selection_class_dis(raw, freq_band=(5., 35.), sub_band_width=4,
  12. sub_band_step=2, alpha=0.4,
  13. tmin=0.5, tmax=2.5,
  14. picks=None, event_id=None,
  15. cv=None,
  16. return_class_dis=False, verbose=None):
  17. r"""Select optimal frequency band based on class distinctiveness measure.
  18. Optimal frequency band is selected by combining a filter bank with
  19. a heuristic based on class distinctiveness on the manifold [1]_:
  20. 1. Filter training raw EEG data for each sub-band of a filter bank.
  21. 2. Estimate covariance matrices of filtered EEG for each sub-band.
  22. 3. Measure the class distinctiveness of each sub-band using the
  23. classDis metric.
  24. 4. Find the optimal frequency band by starting from the sub-band
  25. with the largest classDis and expanding the selected frequency band
  26. as long as the classDis exceeds the threshold :math:`th`:
  27. .. math::
  28. th = max(classDis) - (max(classDis)−min(classDis)) × alpha
  29. Parameters
  30. ----------
  31. raw : Raw object
  32. An instance of Raw from MNE.
  33. freq_band : list, default=(5., 35.)
  34. Frequency band for band-pass filtering.
  35. sub_band_width : float, default=4.
  36. Frequency bandwidth of filter bank.
  37. sub_band_step : float, default=2.
  38. Step length of each filter bank.
  39. alpha : float, default=0.4
  40. Parameter to define an appropriate threshold for each individual.
  41. tmin, tmax : float, default=0.5, 2.5
  42. Start and end time of the epochs in seconds, relative to
  43. the time-locked event.
  44. picks : str | array_like | slice, default=None
  45. Channels to include. Slices and lists of integers will be
  46. interpreted as channel indices.
  47. If None (default), all channels will pick.
  48. event_id : int | list of int | dict, default=None
  49. Id of the events to consider.
  50. - If dict, the keys can later be used to access associated
  51. events.
  52. - If int, a dict will be created with the id as string.
  53. - If a list, all events with the IDs specified in the list
  54. are used.
  55. - If None, all events will be used and a dict is created
  56. with string integer names corresponding to the event id integers.
  57. cv : cross-validation generator, default=None
  58. An instance of a cross validation iterator from sklearn.
  59. return_class_dis : bool, default=False
  60. Whether to return all_cv_class_dis value.
  61. verbose : bool, str, int, default=None
  62. Control verbosity of the logging output of filtering and .
  63. If None, use the default verbosity level.
  64. Returns
  65. -------
  66. all_cv_best_freq : list
  67. List of the selected frequency band for each hold of
  68. cross validation.
  69. all_cv_class_dis : list, optional
  70. List of class_dis value of each hold of cross validation.
  71. Notes
  72. -----
  73. .. versionadded:: 0.3.1
  74. References
  75. ----------
  76. .. [1] `Class-distinctiveness-based frequency band selection on the
  77. Riemannian manifold for oscillatory activity-based BCIs: preliminary
  78. results
  79. <https://hal.archives-ouvertes.fr/hal-03641137/>`_
  80. M. S. Yamamoto, F. Lotte, F. Yger, and S. Chevallier.
  81. 44th Annual International Conference of the IEEE Engineering
  82. in Medicine & Biology Society (EMBC2022), 2022.
  83. """
  84. subband_fmin = list(np.arange(freq_band[0],
  85. freq_band[1] - sub_band_width + 1.,
  86. sub_band_step))
  87. subband_fmax = list(np.arange(freq_band[0] + sub_band_width,
  88. freq_band[1] + 1., sub_band_step))
  89. n_subband = len(subband_fmin)
  90. all_sub_band_cov = []
  91. for fmin, fmax in zip(subband_fmin, subband_fmax):
  92. cov_data, labels = _get_filtered_cov(raw, picks,
  93. event_id,
  94. fmin,
  95. fmax,
  96. tmin, tmax, verbose)
  97. all_sub_band_cov.append(cov_data)
  98. all_cv_best_freq = []
  99. all_cv_class_dis = []
  100. for i, (train_ind, test_ind) in enumerate(cv.split(all_sub_band_cov[0],
  101. labels)):
  102. all_class_dis = []
  103. for ii in range(n_subband):
  104. class_dis = class_distinctiveness(
  105. all_sub_band_cov[ii][train_ind], labels[train_ind],
  106. exponent=1, metric='riemann', return_num_denom=False)
  107. all_class_dis.append(class_dis)
  108. all_cv_class_dis.append(all_class_dis)
  109. best_freq = _get_best_freq_band(all_class_dis, n_subband,
  110. subband_fmin, subband_fmax, alpha)
  111. all_cv_best_freq.append(best_freq)
  112. if return_class_dis:
  113. return all_cv_best_freq, all_cv_class_dis
  114. else:
  115. return all_cv_best_freq
  116. def _get_filtered_cov(raw, picks, event_id, fmin, fmax, tmin, tmax, verbose):
  117. """Private function to apply band-pass filter and estimate
  118. covariance matrix."""
  119. best_raw_filter = raw.copy().filter(fmin, fmax, method='iir', picks=picks,
  120. verbose=verbose)
  121. events, _ = events_from_annotations(best_raw_filter, event_id=event_id,
  122. verbose=verbose)
  123. epochs = Epochs(
  124. best_raw_filter,
  125. events,
  126. event_id,
  127. tmin,
  128. tmax,
  129. proj=True,
  130. picks=picks,
  131. baseline=None,
  132. preload=True,
  133. verbose=verbose)
  134. labels = epochs.events[:, -1] - 2
  135. epochs_data = epochs.get_data(units="uV")
  136. cov_data = Covariances().transform(epochs_data)
  137. cov_data = Shrinkage().transform(cov_data)
  138. return cov_data, labels
  139. def _get_best_freq_band(all_class_dis, n_subband, subband_fmin, subband_fmax,
  140. alpha):
  141. """Private function to select frequency bands whose class dis value are
  142. higher than the user-specific threshold."""
  143. fmaxstart = np.argmax(all_class_dis)
  144. fmin = np.min(all_class_dis)
  145. fmax = np.max(all_class_dis)
  146. threshold_freq = fmax - (fmax - fmin) * alpha
  147. f0 = fmaxstart
  148. f1 = fmaxstart
  149. while f0 >= 1 and (all_class_dis[f0 - 1] >= threshold_freq):
  150. f0 = f0 - 1
  151. while f1 < n_subband - 1 and (all_class_dis[f1 + 1] >= threshold_freq):
  152. f1 = f1 + 1
  153. best_freq_f0 = subband_fmin[f0]
  154. best_freq_f1 = subband_fmax[f1]
  155. best_freq = [best_freq_f0, best_freq_f1]
  156. return best_freq