model.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import numpy as np
  2. from scipy import signal
  3. from pyriemann.estimation import BlockCovariances
  4. from pyriemann.tangentspace import TangentSpace
  5. from pyriemann.preprocessing import Whitening
  6. from sklearn.pipeline import make_pipeline
  7. from sklearn.base import BaseEstimator, TransformerMixin
  8. from mne.decoding import Vectorizer, CSP
  9. class DecimateFeature(BaseEstimator, TransformerMixin):
  10. """DecimateFeature 类用于对信号进行降采样以达到目标采样频率。"""
  11. def __init__(self, fs, target_fs=10, axis=-1):
  12. """初始化函数,设置原始采样频率 fs,目标采样频率 target_fs,以及降采样操作的轴 axis。"""
  13. self.fs = fs
  14. self.target_fs = target_fs
  15. self.axis = axis
  16. def fit(self, X, y=None):
  17. return self
  18. def transform(self, X, y=None):
  19. """
  20. 对输入数据 X 进行两次降采样操作,以达到目标采样频率。
  21. 使用 scipy.signal.decimate 方法进行降采样,采用零相位滤波以减少延迟。
  22. """
  23. decimate_rate = np.sqrt(self.fs / self.target_fs).astype(np.int16)
  24. X = signal.decimate(X, decimate_rate, axis=self.axis, zero_phase=True)
  25. # to 10Hz
  26. X = signal.decimate(X, decimate_rate, axis=self.axis, zero_phase=True)
  27. return X
  28. class ChannelScaler(BaseEstimator, TransformerMixin):
  29. """
  30. ChannelScaler 类用于对信号的每个通道进行标准化处理。
  31. """
  32. def __init__(self, norm_axis=(0, 2)):
  33. self.channel_mean_ = None
  34. self.channel_std_ = None
  35. self.norm_axis=norm_axis
  36. def fit(self, X, y=None):
  37. '''
  38. :param X: 3d array with shape (n_epochs, n_channels, n_times)
  39. :param y:
  40. :return:
  41. '''
  42. self.channel_mean_ = np.mean(X, axis=self.norm_axis, keepdims=True)
  43. self.channel_std_ = np.std(X, axis=self.norm_axis, keepdims=True)
  44. return self
  45. def transform(self, X, y=None):
  46. X = X.copy()
  47. X -= self.channel_mean_
  48. X /= self.channel_std_
  49. return X
  50. def riemann_feature_embedder(feat_dim, estimator='lwf'):
  51. """
  52. 创建一个特征嵌入管道,利用 Riemann 几何方法进行特征提取和转换。
  53. 参数 feat_dim 定义每个数据块的大小,estimator 选择协方差矩阵的估计方法。
  54. 管道包括通道标准化、块协方差矩阵计算、白化处理以及切换到切线空间的步骤。
  55. """
  56. return make_pipeline(
  57. ChannelScaler(), # not necessary
  58. BlockCovariances(block_size=feat_dim, estimator=estimator),
  59. Whitening(metric='riemann', dim_red={'expl_var': 0.99}),
  60. TangentSpace()
  61. )
  62. def baseline_feature_embedder(fs, target_fs, axis):
  63. """
  64. 创建一个基线特征嵌入管道,主要用于降采样和通道标准化。
  65. 参数 fs, target_fs, axis 分别为原始采样频率、目标采样频率和降采样操作的轴。
  66. 管道包括降采样、通道标准化和向量化处理的步骤。
  67. """
  68. return make_pipeline(
  69. DecimateFeature(fs, target_fs, axis),
  70. ChannelScaler(),
  71. Vectorizer()
  72. )
  73. def cps_feature_embedder(n_chs):
  74. return make_pipeline(
  75. ChannelScaler(),
  76. CSP(n_chs, reg='ledoit_wolf', log=True)
  77. )