import numpy as np
# NOTE: this binds MNE's ``filter`` module and shadows the ``filter`` builtin
# inside this module; every ``filter.`` reference below means mne.filter.
from mne import filter
from mne.time_frequency import tfr_array_morlet
from scipy import signal, fftpack
from sklearn.base import BaseEstimator, TransformerMixin


class FilterbankExtractor(BaseEstimator, TransformerMixin):
    """Scikit-learn transformer wrapping :func:`filterbank_extractor`.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of the input, in Hz.
    filter_banks : array-like of float
        Morlet wavelet frequencies, in Hz.
    """

    def __init__(self, sfreq, filter_banks):
        self.sfreq = sfreq
        # scikit-learn's get_params()/clone()/repr() read every __init__
        # argument back via getattr under its own name, so it must be stored
        # as ``filter_banks``; storing it only as ``freqs`` made them fail.
        self.filter_banks = filter_banks

    @property
    def freqs(self):
        """Backward-compatible alias for ``filter_banks``."""
        return self.filter_banks

    @freqs.setter
    def freqs(self, value):
        self.filter_banks = value

    def fit(self, X, y=None):
        """Do nothing (the transform is stateless) and return ``self``."""
        return self

    def transform(self, X, y=None):
        """Return wavelet power of *X* reshaped to (n_ch * n_freqs, n_times)."""
        return filterbank_extractor(X, self.sfreq, self.filter_banks,
                                    reshape_freqs_dim=True)


def filterbank_extractor(data, sfreq, filter_banks, reshape_freqs_dim=False):
    """Compute Morlet-wavelet power of a single trial.

    Parameters
    ----------
    data : array-like, shape (n_ch, n_times)
        One trial. A leading epoch axis of length 1 is added before calling
        MNE, and ``output='avg_power'`` averages over that axis.
    sfreq : float
        Sampling frequency in Hz.
    filter_banks : array-like of float
        Wavelet frequencies in Hz (a list or an ndarray).
    reshape_freqs_dim : bool
        If True, merge the channel and frequency axes into one.

    Returns
    -------
    power : ndarray
        Shape (n_ch, n_freqs, n_times), or (n_ch * n_freqs, n_times) when
        *reshape_freqs_dim* is True.
    """
    # Convert up front: ``list / 4`` raises TypeError, and ``data[None]``
    # requires an ndarray.
    freqs = np.asarray(filter_banks, dtype=float)
    # n_cycles proportional to frequency keeps n_cycles / freq -- and hence
    # the wavelet's temporal width -- identical for every frequency.
    n_cycles = freqs / 4
    power = tfr_array_morlet(np.asarray(data)[None], sfreq=sfreq, freqs=freqs,
                             n_cycles=n_cycles, output='avg_power',
                             verbose=False)  # (n_ch, n_freqs, n_times)
    if reshape_freqs_dim:
        power = power.reshape((-1, power.shape[-1]))
    return power


class FeatExtractor:
    """Stack band-filtered signals and band amplitude envelopes.

    The output of :class:`LFPExtractor` comes first and the output of
    :class:`HGExtractor` second, concatenated along axis 0.
    """

    def __init__(self, sfreq, lfb_bands, hg_bands):
        self.sfreq = sfreq
        self.lfb_extractor = LFPExtractor(sfreq, lfb_bands)
        self.hgs_extractor = HGExtractor(sfreq, hg_bands)

    def fit(self, X, y=None):
        """Do nothing (the transform is stateless) and return ``self``."""
        return self

    def transform(self, X):
        """Return LFP-band signals stacked on top of band envelopes.

        X: single trial data (n_ch, n_times), as expected by both extractors.
        """
        lfp = self.lfb_extractor.transform(X)
        hgs = self.hgs_extractor.transform(X)
        return np.concatenate((lfp, hgs), axis=0)


class HGExtractor:
    """Band-pass filter each band and take its Hilbert amplitude envelope.

    Parameters
    ----------
    sfreq : float
        Sampling frequency in Hz.
    hg_bands : sequence of (l_freq, h_freq)
        Pass-band edges in Hz; only the first two items of each are used.
    n_jobs : int
        Forwarded to ``mne.filter.filter_data`` (previously hard-coded to 4).
    """

    def __init__(self, sfreq, hg_bands, n_jobs=4):
        self.sfreq = sfreq
        self.hg_bands = hg_bands
        self.n_jobs = n_jobs

    def transform(self, data):
        """Return per-band envelopes stacked along axis 0.

        data: single trial data (n_ch, n_times). The result has
        n_bands * n_ch rows, grouped band by band.
        """
        envelopes = []
        for band in self.hg_bands:
            filtered = filter.filter_data(data, self.sfreq,
                                          l_freq=band[0], h_freq=band[1],
                                          verbose=False, n_jobs=self.n_jobs)
            # |analytic signal| is the instantaneous amplitude of the band.
            envelopes.append(np.abs(fast_hilbert(data=filtered)))
        return np.concatenate(envelopes, axis=0)


def fast_hilbert(data):
    """Return the analytic signal of *data* along its last axis.

    The signal is zero-padded to ``fftpack.next_fast_len`` samples so the
    FFT inside ``scipy.signal.hilbert`` runs on a fast length, and the
    result is cropped back to the original length. Input is cast to
    float64, so the output is complex128 with the same shape as *data*.
    """
    samples = np.asarray(data, dtype=float)
    n_signal = samples.shape[-1]
    fft_length = fftpack.next_fast_len(n_signal)
    # ``N=`` zero-pads inside the FFT, the same as padding by hand.
    return signal.hilbert(samples, N=fft_length, axis=-1)[..., :n_signal]


class LFPExtractor:
    """Band-pass filter the input into each band with a zero-phase IIR filter.

    Parameters
    ----------
    sfreq : float
        Sampling frequency in Hz.
    lfb_bands : sequence of (l_freq, h_freq)
        Pass-band edges in Hz; only the first two items of each are used.
    """

    def __init__(self, sfreq, lfb_bands):
        self.sfreq = sfreq
        self.lfb_bands = lfb_bands

    def transform(self, data):
        """Return band-filtered copies of *data* stacked along axis 0.

        data: single trial data (n_ch, n_times). The result has
        n_bands * n_ch rows, grouped band by band.
        """
        # verbose=False matches HGExtractor and stops MNE from logging the
        # filter design on every call.
        lfp_data = [
            filter.filter_data(data, self.sfreq, band[0], band[1],
                               method='iir', phase='zero', verbose=False)
            for band in self.lfb_bands
        ]
        return np.concatenate(lfp_data, axis=0)