model.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
import numpy as np
from mne.decoding import Vectorizer
from pyriemann.estimation import BlockCovariances, Covariances
from pyriemann.tangentspace import TangentSpace
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.utils.validation import check_is_fitted
  9. class FeatureSelector(BaseEstimator, TransformerMixin):
  10. def __init__(self, feature_type, lfb_dim, hgs_dim):
  11. self.feature_type = feature_type
  12. self.lfb_dim = lfb_dim
  13. self.hgs_dim = hgs_dim
  14. def fit(self, X, y=None):
  15. return self
  16. def transform(self, X, y=None):
  17. if self.feature_type == 'lfb':
  18. return X[:, 0:self.lfb_dim, 0:self.lfb_dim].copy()
  19. else:
  20. return X[:, self.lfb_dim:self.lfb_dim + self.hgs_dim, self.lfb_dim:self.lfb_dim + self.hgs_dim].copy()
  21. class ChannelScaler(BaseEstimator, TransformerMixin):
  22. def __init__(self, norm_axis=(0, 2)):
  23. self.channel_mean_ = None
  24. self.channel_std_ = None
  25. self.norm_axis=norm_axis
  26. def fit(self, X, y=None):
  27. '''
  28. :param X: 3d array with shape (n_epochs, n_channels, n_times)
  29. :param y:
  30. :return:
  31. '''
  32. self.channel_mean_ = np.mean(X, axis=self.norm_axis, keepdims=True)
  33. self.channel_std_ = np.std(X, axis=self.norm_axis, keepdims=True)
  34. return self
  35. def transform(self, X, y=None):
  36. X = X.copy()
  37. X -= self.channel_mean_
  38. X /= self.channel_std_
  39. return X
  40. def stacking_riemann(lfb_dim, hgs_dim, C_lfb=1., C_hgs=1.):
  41. clf_lfb = make_pipeline(
  42. FeatureSelector('lfb', lfb_dim, hgs_dim),
  43. TangentSpace(),
  44. LogisticRegression(C=C_lfb)
  45. )
  46. clf_hgs = make_pipeline(
  47. FeatureSelector('hgs', lfb_dim, hgs_dim),
  48. TangentSpace(),
  49. LogisticRegression(C=C_hgs)
  50. )
  51. sclf = StackingClassifier(
  52. estimators=[('clf_lfb', clf_lfb), ('clf_hgs', clf_hgs)],
  53. final_estimator=LogisticRegression(), n_jobs=2
  54. )
  55. return sclf
  56. def one_stage_riemann(C=1.):
  57. return make_pipeline(
  58. Covariances(estimator='lwf'),
  59. TangentSpace(),
  60. LogisticRegression(C=C)
  61. )
  62. def baseline_model(C=1.):
  63. return make_pipeline(
  64. ChannelScaler(),
  65. Vectorizer(),
  66. LogisticRegression(C=C)
  67. )