frequencybandselection_helpers.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. """
  2. =================================
  3. Frequency Band Selection Helpers
  4. =================================
  5. This file contains helper functions for the frequency band selection example
  6. """
  7. import numpy as np
  8. from mne import Epochs, events_from_annotations
  9. from scipy.interpolate import interp1d
  10. from pyriemann.estimation import Covariances, Shrinkage
  11. from pyriemann.classification import class_distinctiveness
  12. def freq_selection_class_dis(raw, subband_fmin, subband_fmax, alpha=0.4,
  13. band_div=50,
  14. tmin=0.5, tmax=2.5,
  15. picks=None, event_id=None,
  16. cv=None,
  17. return_class_dis=False, verbose=None):
  18. r"""Select optimal frequency band based on class distinctiveness measure.
  19. Optimal frequency band is selected by combining a filter bank with
  20. a heuristic based on class distinctiveness on the manifold [1]_:
  21. 1. Filter training raw EEG data for each sub-band of a filter bank.
  22. 2. Estimate covariance matrices of filtered EEG for each sub-band.
  23. 3. Measure the class distinctiveness of each sub-band using the
  24. classDis metric.
  25. 4. Find the optimal frequency band by starting from the sub-band
  26. with the largest classDis and expanding the selected frequency band
  27. as long as the classDis exceeds the threshold :math:`th`:
  28. .. math::
  29. th = max(classDis) - (max(classDis)−min(classDis)) × alpha
  30. Parameters
  31. ----------
  32. raw : Raw object
  33. An instance of Raw from MNE.
  34. subband_fmin:
  35. List of float that describes the subband lower bound
  36. subband_fmax:
  37. List of float that describes the subband upper bound
  38. alpha : float, default=0.4
  39. Parameter to define an appropriate threshold for each individual.
  40. band_div: float, default=50.
  41. Dividing point of low frequency bands and high frequency bands.
  42. tmin, tmax : float, default=0.5, 2.5
  43. Start and end time of the epochs in seconds, relative to
  44. the time-locked event.
  45. picks : str | array_like | slice, default=None
  46. Channels to include. Slices and lists of integers will be
  47. interpreted as channel indices.
  48. If None (default), all channels will pick.
  49. event_id : int | list of int | dict, default=None
  50. Id of the events to consider.
  51. - If dict, the keys can later be used to access associated
  52. events.
  53. - If int, a dict will be created with the id as string.
  54. - If a list, all events with the IDs specified in the list
  55. are used.
  56. - If None, all events will be used and a dict is created
  57. with string integer names corresponding to the event id integers.
  58. cv : cross-validation generator, default=None
  59. An instance of a cross validation iterator from sklearn.
  60. return_class_dis : bool, default=False
  61. Whether to return all_cv_class_dis value.
  62. verbose : bool, str, int, default=None
  63. Control verbosity of the logging output of filtering and .
  64. If None, use the default verbosity level.
  65. Returns
  66. -------
  67. all_cv_best_freq : list
  68. List of the selected frequency band for each hold of
  69. cross validation.
  70. all_cv_class_dis : list, optional
  71. List of class_dis value of each hold of cross validation.
  72. Notes
  73. -----
  74. .. versionadded:: 0.3.1
  75. References
  76. ----------
  77. .. [1] `Class-distinctiveness-based frequency band selection on the
  78. Riemannian manifold for oscillatory activity-based BCIs: preliminary
  79. results
  80. <https://hal.archives-ouvertes.fr/hal-03641137/>`_
  81. M. S. Yamamoto, F. Lotte, F. Yger, and S. Chevallier.
  82. 44th Annual International Conference of the IEEE Engineering
  83. in Medicine & Biology Society (EMBC2022), 2022.
  84. """
  85. subband_fmean = np.sqrt((subband_fmin * subband_fmax))
  86. n_subband = len(subband_fmin)
  87. low_index = subband_fmean <= band_div
  88. high_index = subband_fmean > band_div
  89. n_subband_low = len(subband_fmean[low_index])
  90. n_subband_high = len(subband_fmean[high_index])
  91. all_sub_band_cov = []
  92. for fmin, fmax in zip(subband_fmin, subband_fmax):
  93. cov_data, labels = _get_filtered_cov(raw, picks,
  94. event_id,
  95. fmin,
  96. fmax,
  97. tmin, tmax, verbose)
  98. all_sub_band_cov.append(cov_data)
  99. all_cv_best_freq = []
  100. all_cv_class_dis = []
  101. for i, (train_ind, test_ind) in enumerate(cv.split(all_sub_band_cov[0],
  102. labels)):
  103. all_class_dis = []
  104. for ii in range(n_subband):
  105. class_dis = class_distinctiveness(
  106. all_sub_band_cov[ii][train_ind], labels[train_ind],
  107. exponent=1, metric='riemann', return_num_denom=False)
  108. all_class_dis.append(class_dis)
  109. all_class_dis = np.array(all_class_dis)
  110. all_cv_class_dis.append(all_class_dis)
  111. best_freq_low = _get_best_freq_band(all_class_dis[low_index], n_subband_low, subband_fmin[low_index], subband_fmax[low_index], alpha)
  112. best_freq_high = _get_best_freq_band(all_class_dis[high_index], n_subband_high, subband_fmin[high_index], subband_fmax[high_index], alpha)
  113. all_cv_best_freq.append((best_freq_low, best_freq_high))
  114. if return_class_dis:
  115. return all_cv_best_freq, all_cv_class_dis
  116. else:
  117. return all_cv_best_freq
  118. def _get_filtered_cov(raw, picks, event_id, fmin, fmax, tmin, tmax, verbose):
  119. """Private function to apply band-pass filter and estimate
  120. covariance matrix."""
  121. best_raw_filter = raw.copy().filter(fmin, fmax, method='iir', picks=picks,
  122. verbose=verbose)
  123. events, _ = events_from_annotations(best_raw_filter, event_id=event_id,
  124. verbose=verbose)
  125. epochs = Epochs(
  126. best_raw_filter,
  127. events,
  128. event_id,
  129. tmin,
  130. tmax,
  131. proj=True,
  132. picks=picks,
  133. baseline=None,
  134. preload=True,
  135. verbose=verbose)
  136. labels = epochs.events[:, -1] - 2
  137. epochs_data = epochs.get_data(units="uV")
  138. cov_data = Covariances().transform(epochs_data)
  139. cov_data = Shrinkage().transform(cov_data)
  140. return cov_data, labels
  141. def _get_best_freq_band(all_class_dis, n_subband, subband_fmin, subband_fmax,
  142. alpha):
  143. """Private function to select frequency bands whose class dis value are
  144. higher than the user-specific threshold."""
  145. fmaxstart = np.argmax(all_class_dis)
  146. fmin = np.min(all_class_dis)
  147. fmax = np.max(all_class_dis)
  148. threshold_freq = fmax - (fmax - fmin) * alpha
  149. f0 = fmaxstart
  150. f1 = fmaxstart
  151. while f0 >= 1 and (all_class_dis[f0 - 1] >= threshold_freq):
  152. f0 = f0 - 1
  153. while f1 < n_subband - 1 and (all_class_dis[f1 + 1] >= threshold_freq):
  154. f1 = f1 + 1
  155. best_freq_f0 = subband_fmin[f0]
  156. best_freq_f1 = subband_fmax[f1]
  157. best_freq = [best_freq_f0, best_freq_f1]
  158. return best_freq