import numpy as np
from mne.decoding import Vectorizer
from pyriemann.estimation import Covariances, BlockCovariances
from pyriemann.tangentspace import TangentSpace
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_is_fitted


class FeatureSelector(BaseEstimator, TransformerMixin):
    """Extract one diagonal sub-block from a batch of square matrices.

    With ``feature_type='lfb'`` the leading ``lfb_dim x lfb_dim`` block is
    returned; with ``feature_type='hgs'`` the ``hgs_dim x hgs_dim`` block
    that starts at index ``lfb_dim`` is returned.

    NOTE(review): X is presumably a stack of block covariance matrices of
    shape (n_epochs, lfb_dim + hgs_dim, lfb_dim + hgs_dim) -- confirm
    against the caller.

    :param feature_type: which block to extract, ``'lfb'`` or ``'hgs'``
    :param lfb_dim: size of the first ('lfb') block
    :param hgs_dim: size of the second ('hgs') block
    """

    def __init__(self, feature_type, lfb_dim, hgs_dim):
        self.feature_type = feature_type
        self.lfb_dim = lfb_dim
        self.hgs_dim = hgs_dim

    def _block_slice(self):
        """Return the index slice of the selected block, validating feature_type."""
        if self.feature_type == 'lfb':
            return slice(0, self.lfb_dim)
        if self.feature_type == 'hgs':
            return slice(self.lfb_dim, self.lfb_dim + self.hgs_dim)
        # Reject unknown values explicitly instead of silently falling back
        # to the 'hgs' block, which would hide typos.
        raise ValueError(
            f"feature_type must be 'lfb' or 'hgs', got {self.feature_type!r}"
        )

    def fit(self, X, y=None):
        """Validate parameters; the selector itself learns nothing.

        :param X: ignored
        :param y: ignored
        :return: self
        """
        self._block_slice()
        return self

    def transform(self, X, y=None):
        """Return a copy of ``X[:, block, block]`` for the selected block.

        :param X: array with at least 3 dimensions (indexed on axes 1 and 2)
        :param y: ignored
        :return: the selected square sub-block of every sample
        """
        block = self._block_slice()
        return X[:, block, block].copy()


class ChannelScaler(BaseEstimator, TransformerMixin):
    """Standardize data with mean/std pooled over ``norm_axis``.

    With the default ``norm_axis=(0, 2)`` and input of shape
    (n_epochs, n_channels, n_times), every channel is centred and scaled by
    its mean and standard deviation over all epochs and time points.

    :param norm_axis: axis or tuple of axes over which the statistics are
        computed (default ``(0, 2)``)

    Fitted attributes (set by ``fit``):

    - ``channel_mean_``: mean over ``norm_axis`` with kept dims, so it
      broadcasts against X.
    - ``channel_std_``: standard deviation over ``norm_axis`` with kept
      dims; entries that were exactly zero are replaced by 1.
    """

    def __init__(self, norm_axis=(0, 2)):
        # Fitted (trailing-underscore) attributes are created in fit() only,
        # so check_is_fitted() can distinguish an unfitted scaler.
        self.norm_axis = norm_axis

    def fit(self, X, y=None):
        '''
        Compute the pooled mean and standard deviation.

        :param X: 3d array with shape (n_epochs, n_channels, n_times)
        :param y: ignored
        :return: self
        '''
        X = np.asarray(X)
        self.channel_mean_ = np.mean(X, axis=self.norm_axis, keepdims=True)
        std = np.std(X, axis=self.norm_axis, keepdims=True)
        # A constant channel has zero std; dividing by it would produce
        # inf/NaN, so such channels are only centred, not scaled.
        self.channel_std_ = np.where(std == 0, 1.0, std)
        return self

    def transform(self, X, y=None):
        '''
        Return ``(X - channel_mean_) / channel_std_`` as a new array.

        :param X: array broadcast-compatible with the fitted statistics
        :param y: ignored
        :return: standardized copy of X
        '''
        check_is_fitted(self)
        # Out-of-place arithmetic: integer-typed input is promoted to float
        # instead of failing an in-place cast; X itself is never modified.
        return (np.asarray(X) - self.channel_mean_) / self.channel_std_


def stacking_riemann(lfb_dim, hgs_dim, C_lfb=1., C_hgs=1., n_jobs=2):
    """Build a stacked classifier over the 'lfb' and 'hgs' matrix blocks.

    Each base model selects its diagonal block with FeatureSelector, maps it
    through pyriemann's TangentSpace and fits a LogisticRegression; a
    LogisticRegression meta-learner combines the two base models.

    :param lfb_dim: size of the 'lfb' block
    :param hgs_dim: size of the 'hgs' block
    :param C_lfb: LogisticRegression ``C`` for the 'lfb' base model
    :param C_hgs: LogisticRegression ``C`` for the 'hgs' base model
    :param n_jobs: ``n_jobs`` passed to StackingClassifier (default 2)
    :return: an unfitted StackingClassifier
    """
    def _branch(feature_type, C):
        # One base model: block selection -> tangent space -> logistic regression.
        return make_pipeline(
            FeatureSelector(feature_type, lfb_dim, hgs_dim),
            TangentSpace(),
            LogisticRegression(C=C),
        )

    return StackingClassifier(
        estimators=[
            ('clf_lfb', _branch('lfb', C_lfb)),
            ('clf_hgs', _branch('hgs', C_hgs)),
        ],
        final_estimator=LogisticRegression(),
        n_jobs=n_jobs,
    )


def one_stage_riemann(C=1.):
    """Build a covariance -> tangent space -> logistic regression pipeline.

    Covariances are estimated with pyriemann's ``estimator='lwf'`` option.

    :param C: LogisticRegression ``C``
    :return: an unfitted sklearn Pipeline
    """
    return make_pipeline(
        Covariances(estimator='lwf'),
        TangentSpace(),
        LogisticRegression(C=C),
    )


def baseline_model(C=1.):
    """Build a scaled, vectorized logistic regression baseline.

    The pipeline is ChannelScaler (default ``norm_axis``) followed by mne's
    Vectorizer and a LogisticRegression.

    :param C: LogisticRegression ``C``
    :return: an unfitted sklearn Pipeline
    """
    return make_pipeline(
        ChannelScaler(),
        Vectorizer(),
        LogisticRegression(C=C),
    )