|
@@ -8,13 +8,13 @@ from sklearn.base import BaseEstimator, TransformerMixin
|
|
|
class FilterbankExtractor(BaseEstimator, TransformerMixin):
|
|
|
def __init__(self, sfreq, filter_banks):
|
|
|
self.sfreq = sfreq
|
|
|
- self.freqs = filter_banks
|
|
|
+ self.filter_banks = filter_banks
|
|
|
|
|
|
def fit(self, X, y=None):
|
|
|
return self
|
|
|
|
|
|
def transform(self, X, y=None):
|
|
|
- return filterbank_extractor(X, self.sfreq, self.freqs, reshape_freqs_dim=True)
|
|
|
+ return filterbank_extractor(X, self.sfreq, self.filter_banks, reshape_freqs_dim=True)
|
|
|
|
|
|
|
|
|
def filterbank_extractor(data, sfreq, filter_banks, reshape_freqs_dim=False):
|