|
@@ -34,16 +34,23 @@ def filterbank_extractor(data, sfreq, filter_banks, reshape_freqs_dim=False):
|
|
|
class FeatExtractor:
|
|
|
def __init__(self, sfreq, lfb_bands, hg_bands):
|
|
|
self.sfreq = sfreq
|
|
|
- self.lfb_extractor = LFPExtractor(sfreq, lfb_bands)
|
|
|
- self.hgs_extractor = HGExtractor(sfreq, hg_bands)
|
|
|
+ self.use_lfb = lfb_bands is not None
|
|
|
+ self.use_hgb = hg_bands is not None
|
|
|
+ if self.use_lfb:
|
|
|
+ self.lfb_extractor = LFPExtractor(sfreq, lfb_bands)
|
|
|
+ if self.use_hgb:
|
|
|
+ self.hgs_extractor = HGExtractor(sfreq, hg_bands)
|
|
|
|
|
|
def fit(self, X, y=None):
|
|
|
return self
|
|
|
|
|
|
def transform(self, X):
|
|
|
- lfp = self.lfb_extractor.transform(X)
|
|
|
- hgs = self.hgs_extractor.transform(X)
|
|
|
- return np.concatenate((lfp, hgs), axis=0)
|
|
|
+ feature = []
|
|
|
+ if self.use_lfb:
|
|
|
+ feature.append(self.lfb_extractor.transform(X))
|
|
|
+ if self.use_hgb:
|
|
|
+ feature.append(self.hgs_extractor.transform(X))
|
|
|
+ return np.concatenate(feature, axis=0)
|
|
|
|
|
|
|
|
|
class HGExtractor:
|