import numpy as np
from mne import filter
from mne.time_frequency import tfr_array_morlet
from scipy import signal, fftpack
from sklearn.base import BaseEstimator, TransformerMixin


class FilterbankExtractor(BaseEstimator, TransformerMixin):
    """Stateless scikit-learn transformer around ``filterbank_extractor``.

    Parameters
    ----------
    sfreq
        Sampling frequency, forwarded unchanged to ``filterbank_extractor``.
    filter_banks
        Frequencies of interest, forwarded unchanged to
        ``filterbank_extractor``.
    """

    def __init__(self, sfreq, filter_banks):
        # Constructor arguments are stored verbatim, as scikit-learn expects
        # for get_params / clone to work.
        self.filter_banks = filter_banks
        self.sfreq = sfreq

    def fit(self, X, y=None):
        """Learn nothing and return the transformer itself."""
        return self

    def transform(self, X, y=None):
        """Return filter-bank power of ``X`` as a 2-D array.

        ``reshape_freqs_dim=True`` is always requested, so every axis except
        the last one is flattened into a single leading axis. ``y`` is
        ignored.
        """
        features = filterbank_extractor(
            X,
            sfreq=self.sfreq,
            filter_banks=self.filter_banks,
            reshape_freqs_dim=True,
        )
        return features


def filterbank_extractor(data, sfreq, filter_banks, reshape_freqs_dim=False):
    """Compute trial-averaged Morlet-wavelet power at the given frequencies.

    Parameters
    ----------
    data : ndarray
        Signal of one trial. A leading axis of length 1 is prepended before
        the array is handed to ``tfr_array_morlet``.
        NOTE(review): presumably shaped (n_ch, n_times) -- confirm with callers.
    sfreq : float
        Sampling frequency passed to ``tfr_array_morlet``.
    filter_banks : array-like
        Wavelet frequencies. Lists, tuples and scalars are accepted as well
        as ndarrays.
    reshape_freqs_dim : bool, default False
        If True, flatten every axis except the last into one leading axis.

    Returns
    -------
    ndarray
        The ``'avg_power'`` output of ``tfr_array_morlet``, optionally
        reshaped to 2-D.
    """
    # Coerce to an ndarray first: dividing a plain list or tuple by 4 raises
    # TypeError, whereas ndarray inputs pass through unchanged.
    freqs = np.atleast_1d(filter_banks)
    # Cycle count proportional to frequency (freq / 4 for every wavelet).
    n_cycles = freqs / 4
    power = tfr_array_morlet(data[None],
                             sfreq=sfreq,
                             freqs=freqs,
                             n_cycles=n_cycles,
                             output='avg_power',
                             verbose=False)
    # (n_ch, n_freqs, n_times)
    if reshape_freqs_dim:
        power = power.reshape((-1, power.shape[-1]))
    return power


class FeatExtractor:
    """Stack low-frequency-band and high-gamma features of one trial.

    Parameters
    ----------
    sfreq
        Sampling frequency handed to the underlying extractors.
    lfb_bands
        Bands for ``LFPExtractor``; ``None`` disables that extractor.
    hg_bands
        Bands for ``HGExtractor``; ``None`` disables that extractor.
    """

    def __init__(self, sfreq, lfb_bands, hg_bands):
        self.sfreq = sfreq
        self.use_lfb = lfb_bands is not None
        self.use_hgb = hg_bands is not None
        # Each extractor attribute only exists when its flag is True.
        if self.use_lfb:
            self.lfb_extractor = LFPExtractor(sfreq, lfb_bands)
        if self.use_hgb:
            self.hgs_extractor = HGExtractor(sfreq, hg_bands)

    def fit(self, X, y=None):
        """Learn nothing and return the extractor itself."""
        return self

    def transform(self, X):
        """Return the enabled extractors' outputs concatenated along axis 0.

        Low-frequency-band features come first, high-gamma features second.

        Raises
        ------
        ValueError
            If neither extractor is enabled.
        """
        feature = []
        if self.use_lfb:
            feature.append(self.lfb_extractor.transform(X))
        if self.use_hgb:
            feature.append(self.hgs_extractor.transform(X))
        if not feature:
            # Same exception type np.concatenate([]) would raise, but with a
            # message that points at the actual misconfiguration.
            raise ValueError(
                "FeatExtractor has no active extractor: "
                "lfb_bands and hg_bands were both None."
            )
        return np.concatenate(feature, axis=0)


class HGExtractor:
    """Extract the analytic-signal envelope of a trial in each given band.

    Parameters
    ----------
    sfreq
        Sampling frequency passed to ``filter.filter_data``.
    hg_bands
        Iterable of bands; element ``[0]`` is used as ``l_freq`` and
        element ``[1]`` as ``h_freq``.
    n_jobs : int, default 4
        Number of jobs passed to ``filter.filter_data``. The default keeps
        the value that used to be hard-coded.
    """

    def __init__(self, sfreq, hg_bands, n_jobs=4):
        self.sfreq = sfreq
        self.hg_bands = hg_bands
        self.n_jobs = n_jobs

    def transform(self, data):
        """
        data: single trial data (n_ch, n_times)

        Returns the per-band envelopes concatenated along axis 0, in the
        order of ``hg_bands``.
        """
        envelopes = []
        for band in self.hg_bands:
            band_signal = filter.filter_data(
                data,
                self.sfreq,
                l_freq=band[0],
                h_freq=band[1],
                verbose=False,
                n_jobs=self.n_jobs,
            )
            # Magnitude of the analytic signal, i.e. the amplitude envelope
            # (not squared power).
            envelopes.append(np.abs(fast_hilbert(data=band_signal)))
        return np.concatenate(envelopes, axis=0)
        

def fast_hilbert(data):
    """Return the analytic signal of ``data`` along its last axis.

    The input is copied into a zero-padded buffer whose last-axis length is
    ``fftpack.next_fast_len`` of the original length, transformed with
    ``signal.hilbert``, and cut back to the original number of samples.
    """
    n_samples = data.shape[-1]
    padded_len = fftpack.next_fast_len(n_samples)
    # Selects the original samples in both the padded input and the output.
    valid = (Ellipsis, slice(None, n_samples))

    # np.zeros defaults to float64, so the transform runs in double precision
    # whatever the input dtype is.
    buffer = np.zeros((*data.shape[:-1], padded_len))
    buffer[valid] = data

    analytic = signal.hilbert(buffer, axis=-1)
    return analytic[valid]


class LFPExtractor:
    """Filter one trial into each configured low-frequency band.

    ``sfreq`` is the sampling frequency passed to ``filter.filter_data``;
    ``lfb_bands`` is an iterable of bands whose element ``[0]`` is the lower
    and element ``[1]`` the upper filter frequency.
    """

    def __init__(self, sfreq, lfb_bands):
        self.lfb_bands = lfb_bands
        self.sfreq = sfreq

    def transform(self, data):
        """
        data: single trial data (n_ch, n_times)

        Each band is filtered with a zero-phase IIR filter; the filtered
        copies are concatenated along axis 0 in the order of ``lfb_bands``.
        """
        per_band = [
            filter.filter_data(
                data,
                self.sfreq,
                band[0],
                band[1],
                method='iir',
                phase='zero',
                verbose=False,
            )
            for band in self.lfb_bands
        ]
        return np.concatenate(per_band, axis=0)