123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- import numpy as np
- from mne import filter
- from mne.time_frequency import tfr_array_morlet
- from scipy import signal, fftpack
- from sklearn.base import BaseEstimator, TransformerMixin
class FilterbankExtractor(BaseEstimator, TransformerMixin):
    """Scikit-learn transformer around :func:`filterbank_extractor`.

    ``transform`` computes Morlet wavelet power for one input and merges the
    channel and frequency axes (``reshape_freqs_dim=True``).

    Parameters
    ----------
    sfreq : float
        Sampling frequency, forwarded unchanged to ``filterbank_extractor``.
    filter_banks : array-like
        Wavelet centre frequencies, forwarded unchanged to
        ``filterbank_extractor``.
    """

    def __init__(self, sfreq, filter_banks):
        # Attribute names mirror the constructor arguments so that
        # BaseEstimator.get_params / set_params can discover them.
        self.sfreq = sfreq
        self.filter_banks = filter_banks

    def fit(self, X, y=None):
        """Do nothing (the transform is stateless) and return ``self``."""
        return self

    def transform(self, X, y=None):
        """Return flattened Morlet power for ``X``.

        ``X`` is presumably a single trial of shape (n_ch, n_times), because
        ``filterbank_extractor`` prepends an epoch axis. TODO confirm against
        callers. ``y`` is accepted for API compatibility and ignored.
        """
        sampling_rate = self.sfreq
        centre_freqs = self.filter_banks
        return filterbank_extractor(
            X,
            sampling_rate,
            centre_freqs,
            reshape_freqs_dim=True,
        )
def filterbank_extractor(data, sfreq, filter_banks, reshape_freqs_dim=False):
    """Compute Morlet wavelet power of one trial with ``tfr_array_morlet``.

    Parameters
    ----------
    data : array-like
        Single-trial signal, presumably (n_ch, n_times). A length-1 epoch
        axis is prepended before calling ``tfr_array_morlet``.
    sfreq : float
        Sampling frequency, passed to ``tfr_array_morlet``.
    filter_banks : array-like of float
        Wavelet centre frequencies. Plain lists and tuples are accepted as
        well as NumPy arrays.
    reshape_freqs_dim : bool, optional
        If True, merge the channel and frequency axes so the result is
        2-D with shape ``(n_ch * n_freqs, n_times)``.

    Returns
    -------
    ndarray
        Power of shape ``(n_ch, n_freqs, n_times)``, or
        ``(n_ch * n_freqs, n_times)`` when ``reshape_freqs_dim`` is True.
    """
    # Convert once so the division below also works for plain Python
    # sequences (``[4, 8] / 4`` raises TypeError).
    freqs = np.asarray(filter_banks, dtype=float)
    # Cycle count scales linearly with frequency: freq / 4.
    n_cycles = freqs / 4.0
    power = tfr_array_morlet(np.asarray(data)[np.newaxis],
                             sfreq=sfreq,
                             freqs=freqs,
                             n_cycles=n_cycles,
                             output='avg_power',
                             verbose=False)
    # 'avg_power' averages over the (single) epoch axis:
    # shape is (n_ch, n_freqs, n_times).
    if reshape_freqs_dim:
        power = power.reshape((-1, power.shape[-1]))
    return power
class FeatExtractor:
    """Stack band-filtered LFP signals and high-gamma envelopes of one trial.

    Parameters
    ----------
    sfreq : float
        Sampling frequency, shared by both sub-extractors.
    lfb_bands : sequence of (low, high) pairs
        Bands handed to ``LFPExtractor``.
    hg_bands : sequence of (low, high) pairs
        Bands handed to ``HGExtractor``.
    """

    def __init__(self, sfreq, lfb_bands, hg_bands):
        self.sfreq = sfreq
        self.lfb_extractor = LFPExtractor(sfreq, lfb_bands)
        self.hgs_extractor = HGExtractor(sfreq, hg_bands)

    def fit(self, X, y=None):
        """Do nothing (scikit-learn style interface) and return ``self``."""
        return self

    def transform(self, X):
        """Return LFP features followed by HG features, joined along axis 0."""
        # Tuple order fixes evaluation order: LFP first, then HG.
        feature_blocks = (
            self.lfb_extractor.transform(X),
            self.hgs_extractor.transform(X),
        )
        return np.concatenate(feature_blocks, axis=0)
class HGExtractor:
    """High-gamma envelope extractor.

    For each band, the trial is band-pass filtered with
    ``mne.filter.filter_data`` and the magnitude of its analytic signal
    (from ``fast_hilbert``) is taken. Per-band results are stacked along
    axis 0.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of the data.
    hg_bands : iterable of (low, high) pairs
        Pass-band edges, used as ``l_freq`` / ``h_freq`` for each band.
    n_jobs : int, optional
        Parallel jobs forwarded to ``filter_data``. The default of 4
        preserves the previously hard-coded value.
    """

    def __init__(self, sfreq, hg_bands, n_jobs=4):
        self.sfreq = sfreq
        self.hg_bands = hg_bands
        self.n_jobs = n_jobs

    def transform(self, data):
        """Return the stacked band envelopes of a single trial.

        Parameters
        ----------
        data : ndarray, shape (n_ch, n_times)
            Single-trial data.

        Returns
        -------
        ndarray, shape (n_ch * n_bands, n_times)
            All channels of the first band, then all channels of the next.

        Raises
        ------
        ValueError
            If ``hg_bands`` contains no bands.
        """
        # getattr keeps instances that lack the attribute (e.g. unpickled
        # from before n_jobs existed) working with the old default.
        n_jobs = getattr(self, 'n_jobs', 4)
        envelopes = []
        for band in self.hg_bands:
            filtered = filter.filter_data(data, self.sfreq,
                                          l_freq=band[0], h_freq=band[1],
                                          verbose=False, n_jobs=n_jobs)
            envelopes.append(np.abs(fast_hilbert(data=filtered)))
        if not envelopes:
            # Same exception type np.concatenate would raise, with a clearer message.
            raise ValueError("hg_bands must contain at least one (low, high) band")
        return np.concatenate(envelopes, axis=0)
-
def fast_hilbert(data):
    """Return the analytic signal along the last axis, using an FFT-friendly length.

    The input is zero-padded on the right up to ``fftpack.next_fast_len`` of
    its length. The padded array goes through ``scipy.signal.hilbert``, and
    the result is cropped back to the original number of samples.

    Parameters
    ----------
    data : ndarray, shape (..., n_times)
        Signal(s), expected real-valued; leading axes are preserved.

    Returns
    -------
    ndarray, same shape as ``data``
        Complex analytic signal; ``np.abs`` of it gives the envelope.
    """
    n_times = data.shape[-1]
    leading_shape = data.shape[:-1]
    padded_len = fftpack.next_fast_len(n_times)
    # The work buffer is float64 regardless of the input dtype.
    buffer = np.zeros(leading_shape + (padded_len,))
    buffer[..., :n_times] = data
    analytic = signal.hilbert(buffer, axis=-1)
    return analytic[..., :n_times]
class LFPExtractor:
    """Band-pass filter a trial into several low-frequency bands.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of the data.
    lfb_bands : sequence of (low, high) pairs
        Pass-band edges; element 0 is passed as ``l_freq`` and element 1 as
        ``h_freq`` to ``filter_data``.
    """

    def __init__(self, sfreq, lfb_bands):
        self.sfreq = sfreq
        self.lfb_bands = lfb_bands

    def transform(self, data):
        """Return band-filtered copies of a single trial, stacked on axis 0.

        Parameters
        ----------
        data : ndarray, shape (n_ch, n_times)
            Single-trial data.

        Returns
        -------
        ndarray, shape (n_ch * n_bands, n_times)
            All channels of the first band, then all channels of the next.
        """
        # IIR filter with phase='zero' for every band.
        filtered_bands = [
            filter.filter_data(data, self.sfreq, band[0], band[1],
                               method='iir', phase='zero', verbose=False)
            for band in self.lfb_bands
        ]
        return np.concatenate(filtered_bands, axis=0)
|