فهرست منبع

Fix: 根据train变化修改测试

dk 1 سال پیش
والد
کامیت
3fcbb17a88
2فایلهای تغییر یافته به همراه3 افزوده شده و 3 حذف شده
  1. 2 2
      backend/bci_core/feature_extractors.py
  2. 1 1
      backend/tests/test_training.py

+ 2 - 2
backend/bci_core/feature_extractors.py

@@ -8,13 +8,13 @@ from sklearn.base import BaseEstimator, TransformerMixin
 class FilterbankExtractor(BaseEstimator, TransformerMixin):
     def __init__(self, sfreq, filter_banks):
         self.sfreq = sfreq
-        self.freqs = filter_banks
+        self.filter_banks = filter_banks
     
     def fit(self, X, y=None):
         return self
     
     def transform(self, X, y=None):
-        return filterbank_extractor(X, self.sfreq, self.freqs, reshape_freqs_dim=True)
+        return filterbank_extractor(X, self.sfreq, self.filter_banks, reshape_freqs_dim=True)
 
 
 def filterbank_extractor(data, sfreq, filter_banks, reshape_freqs_dim=False):

+ 1 - 1
backend/tests/test_training.py

@@ -25,7 +25,7 @@ class TestTraining(unittest.TestCase):
     
     def test_training_baseline(self):
         model = training.train_model(self.raw, self.event_id, model_type='baseline')
-        check_is_fitted(model)
+        check_is_fitted(model[1])
 
     def test_saver(self):
         feat_ext = FeatExtractor(1000, lfb_bands=[(15, 30), [30, 45]], hg_bands=[(55, 95), (105, 145)])