123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323 |
- """对信号进行预处理,主要用于在线/离线算法、绘图之前"""
- from typing import List
- from typing import Optional
- import mne
- import numpy as np
- from scipy import signal
- class PreProcessor(object):
- """信号预处理,包含去基漂,滤波,重参考以及重采样"""
- @classmethod
- def re_reference(cls,
- mne_raw_data,
- methods="average",
- ref_channels: Optional[List[str]] = None):
- """对数据做重参考,主要提供三种常见重参考方法,共平均,按导联,双极导联
- Args:
- mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
- methods (str, optional): "single"按导联,biopolar双极导联,默认为"average"共平均.
- ref_channels (Optional[List[str]], optional): 默认为None,指定时为一个导联标签的列表
- Returns:
- class: mne.io.array.array.RawArray
- """
- if methods == "single":
- return mne_raw_data.copy().set_eeg_reference(
- ref_channels=ref_channels)
- elif methods == "biopolar":
- return mne.set_bipolar_reference(mne_raw_data,
- anode=ref_channels[0],
- cathode=ref_channels[1])
- elif methods == "average":
- return mne_raw_data.copy().set_eeg_reference(
- ref_channels="average")
- @classmethod
- def detrend(cls, mne_raw_data):
- """去基漂-去均值
- 注意:在处理模拟数据(如正弦信号)时,若不足一个周期,处理结果不符预期
- Args:
- mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
- Returns:
- class: mne.io.array.array.RawArray
- """
- sig_mean = np.mean(mne_raw_data.get_data(), axis=1)
- sig_detrended = mne_raw_data.get_data() - sig_mean.reshape(
- sig_mean.shape[0], 1)
- return mne.io.RawArray(sig_detrended, mne_raw_data.info)
- @classmethod
- def detrend_by_linear(cls, mne_raw_data):
- """去基漂-去线性
- Args:
- mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
- Returns:
- class: mne.io.array.array.RawArray
- """
- # axis=0 列方向处理
- sig_detrended = signal.detrend(mne_raw_data.get_data(), axis=-1)
- return mne.io.RawArray(sig_detrended, mne_raw_data.info)
- @classmethod
- def filter(cls,
- mne_raw_data,
- l_freq: Optional[int] = 0.1,
- h_freq: Optional[int] = 40):
- """滤波
- l_freq<h_freq:band pass;
- l_freq>h_freq:band stop;
- l_freq is not None and h_freq is None: high pass;
- l_freq is None and h_freq is not None: low pass
- Args:
- mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
- l_freq (Optional[int], optional): low截至频率. Defaults to 0.1.
- h_freq (Optional[int], optional): high截至频率. Defaults to 40.
- Returns:
- class: mne.io.array.array.RawArray
- """
- return mne_raw_data.copy().filter(l_freq=l_freq, h_freq=h_freq)
- @classmethod
- def resample(cls, mne_raw_data, new_freq):
- """降采样,该函数为了避免混叠,降采样之前会做滤波
- Args:
- mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
- new_freq (float): 新的采样率
- Returns:
- class: mne.io.array.array.RawArray
- """
- return mne_raw_data.copy().resample(sfreq=new_freq)
- @classmethod
- def resample_direct(cls, mne_raw_data, new_freq):
- """降采样,直接间隔抽样
- Args:
- mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
- new_freq (float): 新的采样率
- Returns:
- class: mne.io.array.array.RawArray
- """
- sig = mne_raw_data.get_data()
- step = mne_raw_data.info["sfreq"] / new_freq
- assert len(sig[0]) % step == 0, \
- f"Length if sig ({len(sig)}) can not divided by step({step})"
- sig_resampled = sig[:, 0::int(step)]
- return mne.io.RawArray(sig_resampled, mne_raw_data.info)
- class RealTimeFilter(object):
- """ 实时滤波器
- 输入设计好的滤波器参数(_ce_a, _ce_b),实时滤波。
- y(n) = b(0)*x(n)+b(1)*x(n-1)...-a(1)*y(n-1)-a(2)*y(n-2)...
- Attribute:
- _order_a: 分母阶数
- _order_b: 分子阶数
- _buffer_x: 历史输入
- _buffer_y: 历史输出
- _pos_x: 最新数据点的位置
- _pos_y:
- _ce_a: 分母系数
- _ce_b: 分子系数
- """
- def __init__(self, ce_a: List[float], ce_b: List[float]):
- self._order_a = len(ce_a)
- self._order_b = len(ce_b)
- self._ce_a = ce_a
- self._ce_b = ce_b
- # 环形,old<- ->new
- self._buffer_x = [0.0] * self._order_b # 存储x(n-N)...x(n-1)
- self._buffer_y = [0.0] * self._order_a # 存储y(n-N)...y(n-1)
- self._pos_x = 0 # x(n)的存放位置
- self._pos_y = 0 # y(n)的存放位置
- def filter(self, xn):
- self._buffer_x[self._pos_x] = xn
- weighted_sum_x = self.cal_weighted_sum_x()
- weighted_sum_y = self.cal_weighted_sum_y()
- yn = weighted_sum_x - weighted_sum_y
- self._buffer_y[self._pos_y] = yn
- self._pos_x += 1
- if self._pos_x == self._order_b:
- self._pos_x = 0
- self._pos_y += 1
- if self._pos_y == self._order_a:
- self._pos_y = 0
- return yn
- def cal_weighted_sum_x(self):
- # b(0)*x(n)+b(1)*x(n-1)...b(N-1)*x(n-N+1)
- weighted_sum_x = 0
- for ii in range(self._order_b):
- pos_x = (self._pos_x - ii + self._order_b) % self._order_b
- weighted_sum_x += self._ce_b[ii] * self._buffer_x[pos_x]
- return weighted_sum_x
- def cal_weighted_sum_y(self):
- # a(1)*y(n-1)+a(2)*y(n-2)...+a(N-1)*y(n-N+1)
- weighted_sum_y = 0
- for ii in range(1, self._order_a):
- pos_y = (self._pos_y - ii + self._order_a) % self._order_a
- weighted_sum_y += self._ce_a[ii] * self._buffer_y[pos_y]
- return weighted_sum_y
- @classmethod
- def init_eeg(cls, code, fs=1000):
- """初始化eeg常用实时滤波器
- Args:
- code (int): 预处理类型. 0代表0.5Hz高通,1代表60Hz低通.
- fs (float, optional): 采样率
- Returns:
- RealTimFilter: 实时滤波器实例
- """
- assert code in [0, 1, 2], "Invalid code for eeg RealTimeFilter init!"
- if code == 0:
- # butter 0.5Hz高通
- # aa = [1, -1.982228929792529, 0.982385450614125]
- # bb = [0.991153595101663, -1.982307190203327, 0.991153595101663]
- bb, aa = signal.butter(2, [2*0.5/fs], "hp")
- elif code == 1:
- # 60Hz低通
- # aa = [1, -0.031426266043351]
- # bb = [0.484286866978324, 0.484286866978324]
- bb, aa = signal.butter(1, [2*60/fs])
- elif code == 2:
- # 40Hz低通
- # aa = [1, -0.290526856731916]
- # bb = [0.354736571634042, 0.354736571634042]
- bb, aa = signal.butter(1, [2*40/fs])
- return cls(aa, bb)
- class RealTimeFilterM(object):
- """ 对多个通道同时进行实时滤波器
- 输入设计好的滤波器参数(_ce_a, _ce_b),对每个通道进行实时滤波:
- y(n) = b(0)*x(n)+b(1)*x(n-1)...-a(1)*y(n-1)-a(2)*y(n-2)...
- Attribute:
- _order_a: 分母阶数
- _order_b: 分子阶数
- _channel: 信号的通道数
- _buffer_x: 历史输入
- _buffer_y: 历史输出
- _pos_x: 最新数据点的位置
- _pos_y:
- _ce_a: 分母系数
- _ce_b: 分子系数
- """
- def __init__(self, ce_a: List[float], ce_b: List[float], channel: int):
- self._order_a = len(ce_a)
- self._order_b = len(ce_b)
- self._channel = channel
- self._ce_a = ce_a
- self._ce_b = ce_b
- # 环形,old<- ->new
- self._buffer_x = np.zeros((self._channel, self._order_b),
- dtype=np.float64) # 存储x(n-N)...x(n-1)
- self._buffer_y = np.zeros((self._channel, self._order_a),
- dtype=np.float64) # 存储y(n-N)...y(n-1)
- self._pos_x = 0 # x(n)的存放位置
- self._pos_y = 0 # y(n)的存放位置
- def filter(self, xn: np.ndarray):
- self._buffer_x[:, self._pos_x] = xn
- weighted_sum_x = self.cal_weighted_sum_x()
- weighted_sum_y = self.cal_weighted_sum_y()
- yn = weighted_sum_x - weighted_sum_y
- self._buffer_y[:, self._pos_y] = yn
- self._pos_x += 1
- if self._pos_x == self._order_b:
- self._pos_x = 0
- self._pos_y += 1
- if self._pos_y == self._order_a:
- self._pos_y = 0
- return yn
- def cal_weighted_sum_x(self):
- # b(0)*x(n)+b(1)*x(n-1)...b(N-1)*x(n-N+1)
- weighted_sum_x = np.zeros(self._channel, dtype=np.float64)
- for ii in range(self._order_b):
- pos_x = (self._pos_x - ii + self._order_b) % self._order_b
- weighted_sum_x += self._ce_b[ii] * self._buffer_x[:, pos_x]
- return weighted_sum_x
- def cal_weighted_sum_y(self):
- # a(1)*y(n-1)+a(2)*y(n-2)...+a(N-1)*y(n-N+1)
- weighted_sum_y = np.zeros(self._channel, dtype=np.float64)
- for ii in range(1, self._order_a):
- pos_y = (self._pos_y - ii + self._order_a) % self._order_a
- weighted_sum_y += self._ce_a[ii] * self._buffer_y[:, pos_y]
- return weighted_sum_y
- @classmethod
- def init_eeg(cls, code, channel, fs=1000):
- """初始化eeg常用实时滤波器
- Args:
- code (int): 预处理类型. 0代表0.5Hz高通,1代表60Hz低通.
- channel(int): 通道数
- fs (float, optional): 采样率. Defaults to 1000.
- Returns:
- RealTimFilterM: 实时滤波器实例
- """
- assert code in [0, 1, 2], "Invalid code for eeg RealTimeFilter init!"
- if code == 0:
- # butter 0.5Hz高通
- # aa = [1, -1.982228929792529, 0.982385450614125]
- # bb = [0.991153595101663, -1.982307190203327, 0.991153595101663]
- bb, aa = signal.butter(2, [2*0.5/fs], "hp")
- elif code == 1:
- # 60Hz低通
- # aa = [1, -0.031426266043351]
- # bb = [0.484286866978324, 0.484286866978324]
- bb, aa = signal.butter(1, [2*60/fs])
- elif code == 2:
- # 40Hz低通
- # aa = [1, -0.290526856731916]
- # bb = [0.354736571634042, 0.354736571634042]
- bb, aa = signal.butter(1, [2*40/fs])
- return cls(aa, bb, channel)
|