Browse Source

Feat: 支持只是用low frequency或者high frequency

dk 1 year ago
parent
commit
489ab30343
2 changed files with 18 additions and 8 deletions
  1. 12 5
      backend/bci_core/feature_extractors.py
  2. 6 3
      backend/training.py

+ 12 - 5
backend/bci_core/feature_extractors.py

@@ -34,16 +34,23 @@ def filterbank_extractor(data, sfreq, filter_banks, reshape_freqs_dim=False):
 class FeatExtractor:
     def __init__(self, sfreq, lfb_bands, hg_bands):
         self.sfreq = sfreq
-        self.lfb_extractor = LFPExtractor(sfreq, lfb_bands)
-        self.hgs_extractor = HGExtractor(sfreq, hg_bands)
+        self.use_lfb = lfb_bands is not None
+        self.use_hgb = hg_bands is not None
+        if self.use_lfb:
+            self.lfb_extractor = LFPExtractor(sfreq, lfb_bands)
+        if self.use_hgb:
+            self.hgs_extractor = HGExtractor(sfreq, hg_bands)
 
     def fit(self, X, y=None):
         return self
     
     def transform(self, X):
-        lfp = self.lfb_extractor.transform(X)
-        hgs = self.hgs_extractor.transform(X)
-        return np.concatenate((lfp, hgs), axis=0)
+        feature = []
+        if self.use_lfb:
+            feature.append(self.lfb_extractor.transform(X))
+        if self.use_hgb:
+            feature.append(self.hgs_extractor.transform(X))
+        return np.concatenate(feature, axis=0)
 
 
 class HGExtractor:

+ 6 - 3
backend/training.py

@@ -68,9 +68,12 @@ def _train_riemann_model(raw, events, duration=1., lf_bands=[(15, 35), (35, 50)]
     X = scaler.fit_transform(X)
 
     # compute covariance
-    lfb_dim = len(lf_bands) * n_ch
-    hgs_dim = len(hg_bands) * n_ch
-    cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
+    feat_dim = []
+    if lf_bands is not None:
+        feat_dim.append(len(lf_bands) * n_ch)
+    if hg_bands is not None:
+        feat_dim.append(len(hg_bands) * n_ch)
+    cov_model = BlockCovariances(feat_dim, estimator='lwf')
     X_cov = cov_model.fit_transform(X)
 
     param = {'C': np.logspace(-5, 4, 10)}