Browse Source

Feat: take freqs as input (for baseline model)

dk 1 year ago
parent
commit
6b936ad854
2 changed files with 12 additions and 11 deletions
  1. 7 8
      backend/bci_core/online.py
  2. 5 3
      backend/training.py

+ 7 - 8
backend/bci_core/online.py

@@ -177,10 +177,8 @@ class HMMModel:
 class BaselineHMM(HMMModel):
     def __init__(self, model, **kwargs):
         if isinstance(model, str):
-            self.model = joblib.load(model)
-        else:
-            self.model = model
-        self.freqs = np.arange(20, 150, 15)
+            model = joblib.load(model)
+        self.feat_extractor, self.model = model
 
         super(BaselineHMM, self).__init__(n_classes=len(self.model.classes_), **kwargs)
     
@@ -189,7 +187,7 @@ class BaselineHMM(HMMModel):
         """
         data = super(BaselineHMM, self).step_probability(fs, data)
         # filter data
-        filter_bank_data = filterbank_extractor(data, fs, self.freqs, reshape_freqs_dim=True)
+        filter_bank_data = self.feat_extractor.transform(data)
         # downsampling
         decimate_rate = np.sqrt(fs / 10).astype(np.int16)
         filter_bank_data = signal.decimate(filter_bank_data, decimate_rate, axis=-1, zero_phase=True)
@@ -202,9 +200,10 @@ class BaselineHMM(HMMModel):
 class RiemannHMM(HMMModel):
     def __init__(self, model, **kwargs):
         if isinstance(model, str):
-            self.feat_extractor, self.scaler, self.cov, self.model = joblib.load(model)
-        else:
-            self.feat_extractor, self.scaler, self.cov, self.model = model
+            model = joblib.load(model)
+
+        self.feat_extractor, self.scaler, self.cov, self.model = model
+
         super(RiemannHMM, self).__init__(n_classes=len(self.model.classes_), **kwargs)
 
     def step_probability(self, fs, data):

+ 5 - 3
backend/training.py

@@ -85,9 +85,11 @@ def _train_riemann_model(raw, events, duration=1., lf_bands=[(15, 35), (35, 50)]
     return [feat_extractor, scaler, cov_model, model_to_train]
 
 
-def _train_baseline_model(raw, events, duration=1., ):
+def _train_baseline_model(raw, events, duration=1., freqs=(20, 150, 15)):
     fs = raw.info['sfreq']
-    filter_bank_data = feature_extractors.filterbank_extractor(raw.get_data(), fs, np.arange(20, 150, 15), reshape_freqs_dim=True)
+    freqs = np.arange(*freqs)
+    filterbank_extractor = feature_extractors.FilterbankExtractor(fs, freqs)
+    filter_bank_data = filterbank_extractor.transform(raw.get_data())
 
     filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])
 
@@ -105,7 +107,7 @@ def _train_baseline_model(raw, events, duration=1., ):
 
     model_to_train = bci_model.baseline_model(**best_param)
     model_to_train.fit(X, y)
-    return model_to_train
+    return filterbank_extractor, model_to_train
 
 
 def model_saver(model, model_path, model_type, subject_id, event_id):