Browse Source

Feat: basic riemann model

dk 1 year ago
parent
commit
dbd64f694d
4 changed files with 19 additions and 37 deletions
  1. +3 −22
      backend/bci_core/model.py
  2. +4 −1
      backend/bci_core/online.py
  3. +11 −13
      backend/training.py
  4. +1 −1
      backend/validation.py

+ 3 - 22
backend/bci_core/model.py

@@ -1,10 +1,9 @@
 import numpy as np
 
 from sklearn.linear_model import LogisticRegression
-from pyriemann.estimation import Covariances, BlockCovariances
 from pyriemann.tangentspace import TangentSpace
+from pyriemann.preprocessing import Whitening
 
-from sklearn.ensemble import StackingClassifier
 from sklearn.pipeline import make_pipeline
 from sklearn.base import BaseEstimator, TransformerMixin
 
@@ -51,27 +50,9 @@ class ChannelScaler(BaseEstimator, TransformerMixin):
         return X
 
 
-def stacking_riemann(lfb_dim, hgs_dim, C_lfb=1., C_hgs=1.):
-    clf_lfb = make_pipeline(
-        FeatureSelector('lfb', lfb_dim, hgs_dim),
-        TangentSpace(),
-        LogisticRegression(C=C_lfb)
-    )
-    clf_hgs = make_pipeline(
-        FeatureSelector('hgs', lfb_dim, hgs_dim),
-        TangentSpace(),
-        LogisticRegression(C=C_hgs)
-    )
-    sclf = StackingClassifier(
-        estimators=[('clf_lfb', clf_lfb), ('clf_hgs', clf_hgs)],
-        final_estimator=LogisticRegression(), n_jobs=2
-    )
-    return sclf
-
-
-def one_stage_riemann(C=1.):
+def riemann_model(C=1.):
     return make_pipeline(
-        Covariances(estimator='lwf'),
+        Whitening(metric='euclid', dim_red={'expl_var': 0.99}),
         TangentSpace(),
         LogisticRegression(C=C)
     )

+ 4 - 1
backend/bci_core/online.py

@@ -27,6 +27,8 @@ class Controller:
             self.model_type, _ = parse_model_type(model_path)
             if self.model_type == 'baseline':
                 self.real_feedback_model = BaselineHMM(model_path, state_change_threshold=state_change_threshold)
+            elif self.model_type == 'riemann':
+                self.real_feedback_model = RiemannHMM(model_path, state_change_threshold=state_change_threshold)
             else:
                 raise NotImplementedError
         self.virtual_feedback_rate = virtual_feedback_rate
@@ -208,10 +210,11 @@ class RiemannHMM(HMMModel):
         """
         data = super(RiemannHMM, self).step_probability(fs, data)
         data = self.feat_extractor.transform(data)
+        data = data[None]  # pad trial dimension
         # scale data
         data = self.scaler.transform(data)
         # compute cov
         data = self.cov.transform(data)
         # predict proba
-        p = self.model.predict_proba(data[None]).squeeze()
+        p = self.model.predict_proba(data).squeeze()
         return p

+ 11 - 13
backend/training.py

@@ -2,7 +2,6 @@ import logging
 import joblib
 import os
 from datetime import datetime
-from functools import partial
 import yaml
 
 import mne
@@ -19,27 +18,27 @@ from dataloaders import neo
 logging.basicConfig(level=logging.INFO)
 
 
-def train_model(raw, event_id, model_type='baseline'):
+def train_model(raw, event_id, trial_duration=1., model_type='baseline'):
     """
     """
     events, _ = mne.events_from_annotations(raw, event_id=event_id)
     if model_type.lower() == 'baseline':
-        model = _train_baseline_model(raw, events)
+        model = _train_baseline_model(raw, events, duration=trial_duration)
     elif model_type.lower() == 'riemann':
         # TODO: load subject config
-        model = _train_riemann_model(raw, events)
+        model = _train_riemann_model(raw, events, duration=trial_duration)
     else:
         raise NotImplementedError
     return model
 
 
-def _train_riemann_model(raw, events, lfb_bands=[(15, 30), [30, 45]], hg_bands=[(55, 95), (105, 145)]):
+def _train_riemann_model(raw, events, duration=1., lfb_bands=[(15, 35), (35, 55)], hg_bands=[(55, 95), (105, 145)]):
     fs = raw.info['sfreq']
     n_ch = len(raw.ch_names)
     feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
     filtered_data = feat_extractor.transform(raw.get_data())
     # TODO: find proper latency
-    X = bci_utils.cut_epochs((0, 1., fs), filtered_data, events[:, 0])
+    X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
     y = events[:, -1]
     
     scaler = bci_model.ChannelScaler()
@@ -51,24 +50,23 @@ def _train_riemann_model(raw, events, lfb_bands=[(15, 30), [30, 45]], hg_bands=[
     cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
     X_cov = cov_model.fit_transform(X)
 
-    param = {'C_lfb': np.logspace(-4, 0, 5), 'C_hgs': np.logspace(-3, 1, 5)}
+    param = {'C': np.logspace(-5, 4, 10)}
 
-    model_func = partial(bci_model.stacking_riemann, lfb_dim=lfb_dim, hgs_dim=hgs_dim)
-    best_auc, best_param = bci_utils.param_search(model_func, X_cov, y, param)
+    best_auc, best_param = bci_utils.param_search(bci_model.riemann_model, X_cov, y, param)
 
     logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
 
     # train and dump best model
-    model_to_train = model_func(**best_param)
+    model_to_train = bci_model.riemann_model(**best_param)
     model_to_train.fit(X_cov, y)
     return [feat_extractor, scaler, cov_model, model_to_train]
 
 
-def _train_baseline_model(raw, events):
+def _train_baseline_model(raw, events, duration=1., ):
     fs = raw.info['sfreq']
     filter_bank_data = feature_extractors.filterbank_extractor(raw.get_data(), fs, np.arange(20, 150, 15), reshape_freqs_dim=True)
 
-    filter_bank_epoch = bci_utils.cut_epochs((0, 1., fs), filter_bank_data, events[:, 0])
+    filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])
 
     # downsampling to 10 Hz
     # decim 2 times, to 100Hz
@@ -108,7 +106,7 @@ def model_saver(model, model_path, model_type, subject_id, event_id):
 if __name__ == '__main__':
     # TODO: argparse
     subj_name = 'XW01'
-    model_type = 'baseline'
+    model_type = 'riemann'
     # TODO: load subject config
 
     data_dir = f'./data/{subj_name}/'

+ 1 - 1
backend/validation.py

@@ -105,7 +105,7 @@ if __name__ == '__main__':
     # TODO: load subject config
 
     data_dir = f'./data/{subj_name}/'
-    model_path = f'./static/models/{subj_name}/baseline_rest+flex_11-20-2023-19-26-37.pkl'
+    model_path = f'./static/models/{subj_name}/riemann_rest+flex_11-21-2023-16-43-23.pkl'
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']