"""对信号进行预处理,主要用于在线/离线算法、绘图之前""" from typing import List from typing import Optional import mne import numpy as np from scipy import signal class PreProcessor(object): """信号预处理,包含去基漂,滤波,重参考以及重采样""" @classmethod def re_reference(cls, mne_raw_data, methods="average", ref_channels: Optional[List[str]] = None): """对数据做重参考,主要提供三种常见重参考方法,共平均,按导联,双极导联 Args: mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据" methods (str, optional): "single"按导联,biopolar双极导联,默认为"average"共平均. ref_channels (Optional[List[str]], optional): 默认为None,指定时为一个导联标签的列表 Returns: class: mne.io.array.array.RawArray """ if methods == "single": return mne_raw_data.copy().set_eeg_reference( ref_channels=ref_channels) elif methods == "biopolar": return mne.set_bipolar_reference(mne_raw_data, anode=ref_channels[0], cathode=ref_channels[1]) elif methods == "average": return mne_raw_data.copy().set_eeg_reference( ref_channels="average") @classmethod def detrend(cls, mne_raw_data): """去基漂-去均值 注意:在处理模拟数据(如正弦信号)时,若不足一个周期,处理结果不符预期 Args: mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据" Returns: class: mne.io.array.array.RawArray """ sig_mean = np.mean(mne_raw_data.get_data(), axis=1) sig_detrended = mne_raw_data.get_data() - sig_mean.reshape( sig_mean.shape[0], 1) return mne.io.RawArray(sig_detrended, mne_raw_data.info) @classmethod def detrend_by_linear(cls, mne_raw_data): """去基漂-去线性 Args: mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据" Returns: class: mne.io.array.array.RawArray """ # axis=0 列方向处理 sig_detrended = signal.detrend(mne_raw_data.get_data(), axis=-1) return mne.io.RawArray(sig_detrended, mne_raw_data.info) @classmethod def filter(cls, mne_raw_data, l_freq: Optional[int] = 0.1, h_freq: Optional[int] = 40): """滤波 l_freqh_freq:band stop; l_freq is not None and h_freq is None: high pass; l_freq is None and h_freq is not None: low pass Args: mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据" l_freq (Optional[int], optional): low截至频率. Defaults to 0.1. h_freq (Optional[int], optional): high截至频率. Defaults to 40. Returns: class: mne.io.array.array.RawArray """ return mne_raw_data.copy().filter(l_freq=l_freq, h_freq=h_freq) @classmethod def resample(cls, mne_raw_data, new_freq): """降采样,该函数为了避免混叠,降采样之前会做滤波 Args: mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据" new_freq (float): 新的采样率 Returns: class: mne.io.array.array.RawArray """ return mne_raw_data.copy().resample(sfreq=new_freq) @classmethod def resample_direct(cls, mne_raw_data, new_freq): """降采样,直接间隔抽样 Args: mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据" new_freq (float): 新的采样率 Returns: class: mne.io.array.array.RawArray """ sig = mne_raw_data.get_data() step = mne_raw_data.info["sfreq"] / new_freq assert len(sig[0]) % step == 0, \ f"Length if sig ({len(sig)}) can not divided by step({step})" sig_resampled = sig[:, 0::int(step)] return mne.io.RawArray(sig_resampled, mne_raw_data.info) class RealTimeFilter(object): """ 实时滤波器 输入设计好的滤波器参数(_ce_a, _ce_b),实时滤波。 y(n) = b(0)*x(n)+b(1)*x(n-1)...-a(1)*y(n-1)-a(2)*y(n-2)... Attribute: _order_a: 分母阶数 _order_b: 分子阶数 _buffer_x: 历史输入 _buffer_y: 历史输出 _pos_x: 最新数据点的位置 _pos_y: _ce_a: 分母系数 _ce_b: 分子系数 """ def __init__(self, ce_a: List[float], ce_b: List[float]): self._order_a = len(ce_a) self._order_b = len(ce_b) self._ce_a = ce_a self._ce_b = ce_b # 环形,old<- ->new self._buffer_x = [0.0] * self._order_b # 存储x(n-N)...x(n-1) self._buffer_y = [0.0] * self._order_a # 存储y(n-N)...y(n-1) self._pos_x = 0 # x(n)的存放位置 self._pos_y = 0 # y(n)的存放位置 def filter(self, xn): self._buffer_x[self._pos_x] = xn weighted_sum_x = self.cal_weighted_sum_x() weighted_sum_y = self.cal_weighted_sum_y() yn = weighted_sum_x - weighted_sum_y self._buffer_y[self._pos_y] = yn self._pos_x += 1 if self._pos_x == self._order_b: self._pos_x = 0 self._pos_y += 1 if self._pos_y == self._order_a: self._pos_y = 0 return yn def cal_weighted_sum_x(self): # b(0)*x(n)+b(1)*x(n-1)...b(N-1)*x(n-N+1) weighted_sum_x = 0 for ii in range(self._order_b): pos_x = (self._pos_x - ii + self._order_b) % self._order_b weighted_sum_x += self._ce_b[ii] * self._buffer_x[pos_x] return weighted_sum_x def cal_weighted_sum_y(self): # a(1)*y(n-1)+a(2)*y(n-2)...+a(N-1)*y(n-N+1) weighted_sum_y = 0 for ii in range(1, self._order_a): pos_y = (self._pos_y - ii + self._order_a) % self._order_a weighted_sum_y += self._ce_a[ii] * self._buffer_y[pos_y] return weighted_sum_y @classmethod def init_eeg(cls, code, fs=1000): """初始化eeg常用实时滤波器 Args: code (int): 预处理类型. 0代表0.5Hz高通,1代表60Hz低通. fs (float, optional): 采样率 Returns: RealTimFilter: 实时滤波器实例 """ assert code in [0, 1, 2], "Invalid code for eeg RealTimeFilter init!" if code == 0: # butter 0.5Hz高通 # aa = [1, -1.982228929792529, 0.982385450614125] # bb = [0.991153595101663, -1.982307190203327, 0.991153595101663] bb, aa = signal.butter(2, [2*0.5/fs], "hp") elif code == 1: # 60Hz低通 # aa = [1, -0.031426266043351] # bb = [0.484286866978324, 0.484286866978324] bb, aa = signal.butter(1, [2*60/fs]) elif code == 2: # 40Hz低通 # aa = [1, -0.290526856731916] # bb = [0.354736571634042, 0.354736571634042] bb, aa = signal.butter(1, [2*40/fs]) return cls(aa, bb) class RealTimeFilterM(object): """ 对多个通道同时进行实时滤波器 输入设计好的滤波器参数(_ce_a, _ce_b),对每个通道进行实时滤波: y(n) = b(0)*x(n)+b(1)*x(n-1)...-a(1)*y(n-1)-a(2)*y(n-2)... Attribute: _order_a: 分母阶数 _order_b: 分子阶数 _channel: 信号的通道数 _buffer_x: 历史输入 _buffer_y: 历史输出 _pos_x: 最新数据点的位置 _pos_y: _ce_a: 分母系数 _ce_b: 分子系数 """ def __init__(self, ce_a: List[float], ce_b: List[float], channel: int): self._order_a = len(ce_a) self._order_b = len(ce_b) self._channel = channel self._ce_a = ce_a self._ce_b = ce_b # 环形,old<- ->new self._buffer_x = np.zeros((self._channel, self._order_b), dtype=np.float64) # 存储x(n-N)...x(n-1) self._buffer_y = np.zeros((self._channel, self._order_a), dtype=np.float64) # 存储y(n-N)...y(n-1) self._pos_x = 0 # x(n)的存放位置 self._pos_y = 0 # y(n)的存放位置 def filter(self, xn: np.ndarray): self._buffer_x[:, self._pos_x] = xn weighted_sum_x = self.cal_weighted_sum_x() weighted_sum_y = self.cal_weighted_sum_y() yn = weighted_sum_x - weighted_sum_y self._buffer_y[:, self._pos_y] = yn self._pos_x += 1 if self._pos_x == self._order_b: self._pos_x = 0 self._pos_y += 1 if self._pos_y == self._order_a: self._pos_y = 0 return yn def cal_weighted_sum_x(self): # b(0)*x(n)+b(1)*x(n-1)...b(N-1)*x(n-N+1) weighted_sum_x = np.zeros(self._channel, dtype=np.float64) for ii in range(self._order_b): pos_x = (self._pos_x - ii + self._order_b) % self._order_b weighted_sum_x += self._ce_b[ii] * self._buffer_x[:, pos_x] return weighted_sum_x def cal_weighted_sum_y(self): # a(1)*y(n-1)+a(2)*y(n-2)...+a(N-1)*y(n-N+1) weighted_sum_y = np.zeros(self._channel, dtype=np.float64) for ii in range(1, self._order_a): pos_y = (self._pos_y - ii + self._order_a) % self._order_a weighted_sum_y += self._ce_a[ii] * self._buffer_y[:, pos_y] return weighted_sum_y @classmethod def init_eeg(cls, code, channel, fs=1000): """初始化eeg常用实时滤波器 Args: code (int): 预处理类型. 0代表0.5Hz高通,1代表60Hz低通. channel(int): 通道数 fs (float, optional): 采样率. Defaults to 1000. Returns: RealTimFilterM: 实时滤波器实例 """ assert code in [0, 1, 2], "Invalid code for eeg RealTimeFilter init!" if code == 0: # butter 0.5Hz高通 # aa = [1, -1.982228929792529, 0.982385450614125] # bb = [0.991153595101663, -1.982307190203327, 0.991153595101663] bb, aa = signal.butter(2, [2*0.5/fs], "hp") elif code == 1: # 60Hz低通 # aa = [1, -0.031426266043351] # bb = [0.484286866978324, 0.484286866978324] bb, aa = signal.butter(1, [2*60/fs]) elif code == 2: # 40Hz低通 # aa = [1, -0.290526856731916] # bb = [0.354736571634042, 0.354736571634042] bb, aa = signal.butter(1, [2*40/fs]) return cls(aa, bb, channel)