model.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import numpy as np
  2. from sklearn.linear_model import LogisticRegression
  3. from pyriemann.estimation import Covariances, BlockCovariances
  4. from pyriemann.tangentspace import TangentSpace
  5. from sklearn.ensemble import StackingClassifier
  6. from sklearn.preprocessing import FunctionTransformer
  7. from sklearn.pipeline import make_pipeline
  8. from sklearn.base import BaseEstimator, TransformerMixin
  9. from sklearn.preprocessing import StandardScaler
  10. from mne.decoding import Vectorizer
  11. class FeatureSelector(BaseEstimator, TransformerMixin):
  12. def __init__(self, feature_type, lfb_dim, hgs_dim):
  13. self.feature_type = feature_type
  14. self.lfb_dim = lfb_dim
  15. self.hgs_dim = hgs_dim
  16. def fit(self, X, y=None):
  17. return self
  18. def transform(self, X, y=None):
  19. if self.feature_type == 'lfb':
  20. return X[:, 0:self.lfb_dim, 0:self.lfb_dim].copy()
  21. else:
  22. return X[:, self.lfb_dim:self.lfb_dim + self.hgs_dim, self.lfb_dim:self.lfb_dim + self.hgs_dim].copy()
  23. class ChannelScaler(BaseEstimator, TransformerMixin):
  24. def __init__(self, norm_axis=(0, 2)):
  25. self.channel_mean_ = None
  26. self.channel_std_ = None
  27. self.norm_axis=norm_axis
  28. def fit(self, X, y=None):
  29. '''
  30. :param X: 3d array with shape (n_epochs, n_channels, n_times)
  31. :param y:
  32. :return:
  33. '''
  34. self.channel_mean_ = np.mean(X, axis=self.norm_axis, keepdims=True)
  35. self.channel_std_ = np.std(X, axis=self.norm_axis, keepdims=True)
  36. return self
  37. def transform(self, X, y=None):
  38. X = X.copy()
  39. X -= self.channel_mean_
  40. X /= self.channel_std_
  41. return X
  42. def stacking_riemann(lfb_dim, hgs_dim, C_lfb=1., C_hgs=1.):
  43. clf_lfb = make_pipeline(
  44. FeatureSelector('lfb', lfb_dim, hgs_dim),
  45. TangentSpace(),
  46. LogisticRegression(C=C_lfb)
  47. )
  48. clf_hgs = make_pipeline(
  49. FeatureSelector('hgs', lfb_dim, hgs_dim),
  50. TangentSpace(),
  51. LogisticRegression(C=C_hgs)
  52. )
  53. sclf = StackingClassifier(
  54. estimators=[('clf_lfb', clf_lfb), ('clf_hgs', clf_hgs)],
  55. final_estimator=LogisticRegression(), n_jobs=2
  56. )
  57. return sclf
  58. def one_stage_riemann(C=1.):
  59. return make_pipeline(
  60. Covariances(estimator='lwf'),
  61. TangentSpace(),
  62. LogisticRegression(C=C)
  63. )
  64. def baseline_model(C=1.):
  65. return make_pipeline(
  66. ChannelScaler(),
  67. Vectorizer(),
  68. LogisticRegression(C=C)
  69. )