model.py

import numpy as np
from sklearn.linear_model import LogisticRegression
from pyriemann.tangentspace import TangentSpace
from pyriemann.preprocessing import Whitening
from sklearn.pipeline import make_pipeline
from sklearn.base import BaseEstimator, TransformerMixin
from mne.decoding import Vectorizer

class FeatureSelector(BaseEstimator, TransformerMixin):
    """Select one block of a stacked (lfb_dim + hgs_dim) x (lfb_dim + hgs_dim) covariance matrix."""

    def __init__(self, feature_type, lfb_dim, hgs_dim):
        self.feature_type = feature_type
        self.lfb_dim = lfb_dim
        self.hgs_dim = hgs_dim

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        if self.feature_type == 'lfb':
            # top-left block: low-frequency band features
            return X[:, :self.lfb_dim, :self.lfb_dim].copy()
        else:
            # bottom-right block: high-gamma features
            return X[:, self.lfb_dim:self.lfb_dim + self.hgs_dim,
                       self.lfb_dim:self.lfb_dim + self.hgs_dim].copy()

class ChannelScaler(BaseEstimator, TransformerMixin):
    def __init__(self, norm_axis=(0, 2)):
        self.norm_axis = norm_axis

    def fit(self, X, y=None):
        """
        :param X: 3d array with shape (n_epochs, n_channels, n_times)
        :param y: ignored
        :return: self
        """
        self.channel_mean_ = np.mean(X, axis=self.norm_axis, keepdims=True)
        self.channel_std_ = np.std(X, axis=self.norm_axis, keepdims=True)
        return self

    def transform(self, X, y=None):
        # z-score each channel with the statistics estimated in fit()
        X = X.copy()
        X -= self.channel_mean_
        X /= self.channel_std_
        return X

def riemann_model(C=1.):
    # Riemannian pipeline: whitening with dimensionality reduction on SPD matrices,
    # tangent-space projection, then logistic regression.
    return make_pipeline(
        Whitening(metric='riemann', dim_red={'expl_var': 0.99}),
        TangentSpace(),
        LogisticRegression(C=C)
    )

def baseline_model(C=1.):
    # Baseline pipeline on raw epochs: channel-wise z-scoring,
    # flattening to 2d, then logistic regression.
    return make_pipeline(
        ChannelScaler(),
        Vectorizer(),
        LogisticRegression(C=C)
    )
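

# Usage sketch (not part of the original module): a minimal, hypothetical example of
# fitting both pipelines. It assumes epochs of shape (n_epochs, n_channels, n_times)
# and covariance matrices estimated with pyriemann.estimation.Covariances; the data,
# dimensions, and labels below are made up for illustration only.
if __name__ == '__main__':
    from pyriemann.estimation import Covariances

    rng = np.random.default_rng(0)
    n_epochs, n_channels, n_times = 40, 8, 200
    epochs = rng.standard_normal((n_epochs, n_channels, n_times))
    labels = rng.integers(0, 2, size=n_epochs)

    # Baseline: raw epochs -> channel z-scoring -> flatten -> logistic regression.
    baseline = baseline_model(C=1.)
    baseline.fit(epochs, labels)

    # Riemannian model: SPD covariance matrices -> whitening -> tangent space -> logistic regression.
    covs = Covariances(estimator='oas').fit_transform(epochs)
    riemann = riemann_model(C=1.)
    riemann.fit(covs, labels)
    print(baseline.score(epochs, labels), riemann.score(covs, labels))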