feature_extractors.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import numpy as np
  2. from mne import filter
  3. from mne.time_frequency import tfr_array_morlet
  4. from scipy import signal, fftpack
  5. from sklearn.base import BaseEstimator, TransformerMixin
  6. class FilterbankExtractor(BaseEstimator, TransformerMixin):
  7. def __init__(self, sfreq, filter_banks):
  8. self.sfreq = sfreq
  9. self.filter_banks = filter_banks
  10. def fit(self, X, y=None):
  11. return self
  12. def transform(self, X, y=None):
  13. return filterbank_extractor(X, self.sfreq, self.filter_banks, reshape_freqs_dim=True)
  14. def filterbank_extractor(data, sfreq, filter_banks, reshape_freqs_dim=False):
  15. n_cycles = filter_banks / 4
  16. power = tfr_array_morlet(data[None],
  17. sfreq=sfreq,
  18. freqs=filter_banks,
  19. n_cycles=n_cycles,
  20. output='avg_power',
  21. verbose=False)
  22. # (n_ch, n_freqs, n_times)
  23. if reshape_freqs_dim:
  24. power = power.reshape((-1, power.shape[-1]))
  25. return power
  26. class FeatExtractor:
  27. def __init__(self, sfreq, lfb_bands, hg_bands):
  28. self.sfreq = sfreq
  29. self.lfb_extractor = LFPExtractor(sfreq, lfb_bands)
  30. self.hgs_extractor = HGExtractor(sfreq, hg_bands)
  31. def fit(self, X, y=None):
  32. return self
  33. def transform(self, X):
  34. lfp = self.lfb_extractor.transform(X)
  35. hgs = self.hgs_extractor.transform(X)
  36. return np.concatenate((lfp, hgs), axis=0)
  37. class HGExtractor:
  38. def __init__(self, sfreq, hg_bands):
  39. self.sfreq = sfreq
  40. self.hg_bands = hg_bands
  41. def transform(self, data):
  42. """
  43. data: single trial data (n_ch, n_times)
  44. """
  45. hg_data = []
  46. for b in self.hg_bands:
  47. filter_signal = filter.filter_data(data, self.sfreq, l_freq=b[0], h_freq=b[1], verbose=False, n_jobs=4)
  48. signal_power = np.abs(fast_hilbert(data=filter_signal))
  49. hg_data.append(signal_power)
  50. hg_data = np.concatenate(hg_data, axis=0)
  51. return hg_data
  52. def fast_hilbert(data):
  53. n_signal = data.shape[-1]
  54. fft_length = fftpack.next_fast_len(n_signal)
  55. pad_signal = np.zeros((*data.shape[:-1], fft_length))
  56. pad_signal[..., :n_signal] = data
  57. complex_signal = signal.hilbert(pad_signal, axis=-1)[..., :n_signal]
  58. return complex_signal
  59. class LFPExtractor:
  60. def __init__(self, sfreq, lfb_bands):
  61. self.sfreq = sfreq
  62. self.lfb_bands = lfb_bands
  63. def transform(self, data):
  64. """
  65. data: single trial data (n_ch, n_times)
  66. """
  67. lfp_data = []
  68. for b in self.lfb_bands:
  69. band_data = filter.filter_data(data, self.sfreq, b[0], b[1], method='iir', phase='zero', verbose=False)
  70. lfp_data.append(band_data)
  71. lfp_data = np.concatenate(lfp_data, axis=0)
  72. return lfp_data