feature_extractors.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import numpy as np
  2. from mne import filter
  3. from mne.time_frequency import tfr_array_morlet
  4. from scipy import signal, fftpack
  5. from sklearn.base import BaseEstimator, TransformerMixin
  6. class FilterbankExtractor(BaseEstimator, TransformerMixin):
  7. def __init__(self, sfreq, filter_banks):
  8. self.sfreq = sfreq
  9. self.filter_banks = filter_banks
  10. def fit(self, X, y=None):
  11. return self
  12. def transform(self, X, y=None):
  13. return filterbank_extractor(X, self.sfreq, self.filter_banks, reshape_freqs_dim=True)
  14. def filterbank_extractor(data, sfreq, filter_banks, reshape_freqs_dim=False):
  15. n_cycles = filter_banks / 4
  16. power = tfr_array_morlet(data[None],
  17. sfreq=sfreq,
  18. freqs=filter_banks,
  19. n_cycles=n_cycles,
  20. output='avg_power',
  21. verbose=False)
  22. # (n_ch, n_freqs, n_times)
  23. if reshape_freqs_dim:
  24. power = power.reshape((-1, power.shape[-1]))
  25. return power
  26. class FeatExtractor:
  27. def __init__(self, sfreq, lfb_bands, hg_bands):
  28. self.sfreq = sfreq
  29. self.use_lfb = lfb_bands is not None
  30. self.use_hgb = hg_bands is not None
  31. if self.use_lfb:
  32. self.lfb_extractor = LFPExtractor(sfreq, lfb_bands)
  33. if self.use_hgb:
  34. self.hgs_extractor = HGExtractor(sfreq, hg_bands)
  35. def fit(self, X, y=None):
  36. return self
  37. def transform(self, X):
  38. feature = []
  39. if self.use_lfb:
  40. feature.append(self.lfb_extractor.transform(X))
  41. if self.use_hgb:
  42. feature.append(self.hgs_extractor.transform(X))
  43. return np.concatenate(feature, axis=0)
  44. class HGExtractor:
  45. def __init__(self, sfreq, hg_bands):
  46. self.sfreq = sfreq
  47. self.hg_bands = hg_bands
  48. def transform(self, data):
  49. """
  50. data: single trial data (n_ch, n_times)
  51. """
  52. hg_data = []
  53. for b in self.hg_bands:
  54. filter_signal = filter.filter_data(data, self.sfreq, l_freq=b[0], h_freq=b[1], verbose=False, n_jobs=4)
  55. signal_power = np.abs(fast_hilbert(data=filter_signal))
  56. hg_data.append(signal_power)
  57. hg_data = np.concatenate(hg_data, axis=0)
  58. return hg_data
  59. def fast_hilbert(data):
  60. n_signal = data.shape[-1]
  61. fft_length = fftpack.next_fast_len(n_signal)
  62. pad_signal = np.zeros((*data.shape[:-1], fft_length))
  63. pad_signal[..., :n_signal] = data
  64. complex_signal = signal.hilbert(pad_signal, axis=-1)[..., :n_signal]
  65. return complex_signal
  66. class LFPExtractor:
  67. def __init__(self, sfreq, lfb_bands):
  68. self.sfreq = sfreq
  69. self.lfb_bands = lfb_bands
  70. def transform(self, data):
  71. """
  72. data: single trial data (n_ch, n_times)
  73. """
  74. lfp_data = []
  75. for b in self.lfb_bands:
  76. band_data = filter.filter_data(data, self.sfreq, b[0], b[1], method='iir', phase='zero', verbose=False)
  77. lfp_data.append(band_data)
  78. lfp_data = np.concatenate(lfp_data, axis=0)
  79. return lfp_data