feature_extractors.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import numpy as np
  2. from mne import filter
  3. from mne.time_frequency import tfr_array_morlet
  4. from scipy import signal, fftpack
  5. from sklearn.base import BaseEstimator, TransformerMixin
  6. class FilterbankExtractor(BaseEstimator, TransformerMixin):
  7. """
  8. 用于提取滤波器组特征
  9. """
  10. def __init__(self, sfreq, filter_banks):
  11. """
  12. 初始化函数接收两个参数:`sfreq` 和 `filter_banks`。
  13. `sfreq` 是信号的采样频率。
  14. `filter_banks` 是一个包含多个频率的数组,这些频率定义了要应用的滤波器组。
  15. """
  16. self.sfreq = sfreq
  17. self.filter_banks = filter_banks
  18. def fit(self, X, y=None):
  19. """
  20. fit 方法是为了与scikit-learn的接口兼容而定义的。在这种情况下,它不进行任何操作,只是返回实例自身。这是因为特征提取不需要训练过程。
  21. """
  22. return self
  23. def transform(self, X, y=None):
  24. """
  25. transform 方法接收输入数据 X 并使用 filterbank_extractor 函数对其进行变换,然后返回变换后的数据。
  26. 这个方法主要用于将定义的滤波器组应用于输入数据,以提取频率特征。
  27. """
  28. return filterbank_extractor(X, self.sfreq, self.filter_banks, reshape_freqs_dim=True)
  29. def filterbank_extractor(data, sfreq, filter_banks, reshape_freqs_dim=False):
  30. """
  31. filterbank_extractor 是一个独立的函数,负责具体的特征提取过程。
  32. data: 输入数据。
  33. sfreq: 采样频率。
  34. filter_banks: 定义了要提取的频率带的数组。
  35. reshape_freqs_dim: 一个布尔值,指定是否要重新塑形频率维度,默认为 False。
  36. 处理步骤
  37. 1. 计算每个滤波器的周期数 n_cycles,这里简单地将 filter_banks 除以4。
  38. 2. 使用 tfr_array_morlet 函数计算数据的时频表示。这个函数应用Morlet小波变换,用于计算指定频率的平均功率。
  39. 3. 默认情况下,输出的功率维度是 (n_ch, n_freqs, n_times)。如果 reshape_freqs_dim 为 True,则将功率数组重塑,以便频率维度和时间维度合并。
  40. """
  41. n_cycles = filter_banks / 4
  42. power = tfr_array_morlet(data[None],
  43. sfreq=sfreq,
  44. freqs=filter_banks,
  45. n_cycles=n_cycles,
  46. output='avg_power',
  47. verbose=False)
  48. # (n_ch, n_freqs, n_times)
  49. if reshape_freqs_dim:
  50. power = power.reshape((-1, power.shape[-1]))
  51. return power
  52. class FeatExtractor:
  53. """
  54. FeatExtractor 是主要的特征提取器类,负责协调低频带(LFB)和高伽马(HG)频带特征的提取。
  55. """
  56. def __init__(self, sfreq, lfb_bands, hg_bands):
  57. """
  58. 初始化函数,设置采样频率和特定频带的参数。
  59. sfreq: 信号的采样频率。
  60. lfb_bands: 低频带参数,如果不为None,则用于LFB特征提取。
  61. hg_bands: 高伽马频带参数,如果不为None,则用于HG特征提取。
  62. 根据 lfb_bands 和 hg_bands 的值,决定是否初始化相应的特征提取器。
  63. """
  64. self.sfreq = sfreq
  65. self.use_lfb = lfb_bands is not None
  66. self.use_hgb = hg_bands is not None
  67. if self.use_lfb:
  68. self.lfb_extractor = LFPExtractor(sfreq, lfb_bands)
  69. if self.use_hgb:
  70. self.hgs_extractor = HGExtractor(sfreq, hg_bands)
  71. def fit(self, X, y=None):
  72. """为了与scikit-learn兼容而定义的方法,不进行任何操作,仅返回自身实例。"""
  73. return self
  74. def transform(self, X):
  75. """
  76. 对输入数据 X 进行特征提取。
  77. 如果启用了LFB或HG特征提取,则分别调用相应的提取器,并将特征数组合并。
  78. """
  79. feature = []
  80. if self.use_lfb:
  81. feature.append(self.lfb_extractor.transform(X))
  82. if self.use_hgb:
  83. feature.append(self.hgs_extractor.transform(X))
  84. return np.concatenate(feature, axis=0)
  85. class HGExtractor:
  86. def __init__(self, sfreq, hg_bands):
  87. self.sfreq = sfreq
  88. self.hg_bands = hg_bands
  89. def transform(self, data):
  90. """
  91. data: single trial data (n_ch, n_times)
  92. """
  93. hg_data = []
  94. for b in self.hg_bands:
  95. filter_signal = filter.filter_data(data, self.sfreq, l_freq=b[0], h_freq=b[1], verbose=False, n_jobs=4)
  96. signal_power = np.abs(fast_hilbert(data=filter_signal))
  97. hg_data.append(signal_power)
  98. hg_data = np.concatenate(hg_data, axis=0)
  99. return hg_data
  100. def fast_hilbert(data):
  101. n_signal = data.shape[-1]
  102. fft_length = fftpack.next_fast_len(n_signal)
  103. pad_signal = np.zeros((*data.shape[:-1], fft_length))
  104. pad_signal[..., :n_signal] = data
  105. complex_signal = signal.hilbert(pad_signal, axis=-1)[..., :n_signal]
  106. return complex_signal
  107. class LFPExtractor:
  108. def __init__(self, sfreq, lfb_bands):
  109. self.sfreq = sfreq
  110. self.lfb_bands = lfb_bands
  111. def transform(self, data):
  112. """
  113. data: single trial data (n_ch, n_times)
  114. """
  115. lfp_data = []
  116. for b in self.lfb_bands:
  117. band_data = filter.filter_data(data, self.sfreq, b[0], b[1], method='iir', phase='zero', verbose=False)
  118. lfp_data.append(band_data)
  119. lfp_data = np.concatenate(lfp_data, axis=0)
  120. return lfp_data