1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- import numpy as np
- from scipy import signal
- from pyriemann.estimation import BlockCovariances
- from pyriemann.tangentspace import TangentSpace
- from pyriemann.preprocessing import Whitening
- from sklearn.pipeline import make_pipeline
- from sklearn.base import BaseEstimator, TransformerMixin
- from mne.decoding import Vectorizer, CSP
- class DecimateFeature(BaseEstimator, TransformerMixin):
- """DecimateFeature 类用于对信号进行降采样以达到目标采样频率。"""
- def __init__(self, fs, target_fs=10, axis=-1):
- """初始化函数,设置原始采样频率 fs,目标采样频率 target_fs,以及降采样操作的轴 axis。"""
- self.fs = fs
- self.target_fs = target_fs
- self.axis = axis
- def fit(self, X, y=None):
- return self
-
- def transform(self, X, y=None):
- """
- 对输入数据 X 进行两次降采样操作,以达到目标采样频率。
- 使用 scipy.signal.decimate 方法进行降采样,采用零相位滤波以减少延迟。
- """
- decimate_rate = np.sqrt(self.fs / self.target_fs).astype(np.int16)
- X = signal.decimate(X, decimate_rate, axis=self.axis, zero_phase=True)
- # to 10Hz
- X = signal.decimate(X, decimate_rate, axis=self.axis, zero_phase=True)
- return X
- class ChannelScaler(BaseEstimator, TransformerMixin):
- """
- ChannelScaler 类用于对信号的每个通道进行标准化处理。
- """
- def __init__(self, norm_axis=(0, 2)):
- self.channel_mean_ = None
- self.channel_std_ = None
- self.norm_axis=norm_axis
- def fit(self, X, y=None):
- '''
- :param X: 3d array with shape (n_epochs, n_channels, n_times)
- :param y:
- :return:
- '''
- self.channel_mean_ = np.mean(X, axis=self.norm_axis, keepdims=True)
- self.channel_std_ = np.std(X, axis=self.norm_axis, keepdims=True)
- return self
- def transform(self, X, y=None):
- X = X.copy()
- X -= self.channel_mean_
- X /= self.channel_std_
- return X
- def riemann_feature_embedder(feat_dim, estimator='lwf'):
- """
- 创建一个特征嵌入管道,利用 Riemann 几何方法进行特征提取和转换。
- 参数 feat_dim 定义每个数据块的大小,estimator 选择协方差矩阵的估计方法。
- 管道包括通道标准化、块协方差矩阵计算、白化处理以及切换到切线空间的步骤。
- """
- return make_pipeline(
- ChannelScaler(), # not necessary
- BlockCovariances(block_size=feat_dim, estimator=estimator),
- Whitening(metric='riemann', dim_red={'expl_var': 0.99}),
- TangentSpace()
- )
- def baseline_feature_embedder(fs, target_fs, axis):
- """
- 创建一个基线特征嵌入管道,主要用于降采样和通道标准化。
- 参数 fs, target_fs, axis 分别为原始采样频率、目标采样频率和降采样操作的轴。
- 管道包括降采样、通道标准化和向量化处理的步骤。
- """
- return make_pipeline(
- DecimateFeature(fs, target_fs, axis),
- ChannelScaler(),
- Vectorizer()
- )
- def cps_feature_embedder(n_chs):
- return make_pipeline(
- ChannelScaler(),
- CSP(n_chs, reg='ledoit_wolf', log=True)
- )
|