pre_process.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. """对信号进行预处理,主要用于在线/离线算法、绘图之前"""
  2. from typing import List
  3. from typing import Optional
  4. import mne
  5. import numpy as np
  6. from scipy import signal
  7. class PreProcessor(object):
  8. """信号预处理,包含去基漂,滤波,重参考以及重采样"""
  9. @classmethod
  10. def re_reference(cls,
  11. mne_raw_data,
  12. methods="average",
  13. ref_channels: Optional[List[str]] = None):
  14. """对数据做重参考,主要提供三种常见重参考方法,共平均,按导联,双极导联
  15. Args:
  16. mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
  17. methods (str, optional): "single"按导联,biopolar双极导联,默认为"average"共平均.
  18. ref_channels (Optional[List[str]], optional): 默认为None,指定时为一个导联标签的列表
  19. Returns:
  20. class: mne.io.array.array.RawArray
  21. """
  22. if methods == "single":
  23. return mne_raw_data.copy().set_eeg_reference(
  24. ref_channels=ref_channels)
  25. elif methods == "biopolar":
  26. return mne.set_bipolar_reference(mne_raw_data,
  27. anode=ref_channels[0],
  28. cathode=ref_channels[1])
  29. elif methods == "average":
  30. return mne_raw_data.copy().set_eeg_reference(
  31. ref_channels="average")
  32. @classmethod
  33. def detrend(cls, mne_raw_data):
  34. """去基漂-去均值
  35. 注意:在处理模拟数据(如正弦信号)时,若不足一个周期,处理结果不符预期
  36. Args:
  37. mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
  38. Returns:
  39. class: mne.io.array.array.RawArray
  40. """
  41. sig_mean = np.mean(mne_raw_data.get_data(), axis=1)
  42. sig_detrended = mne_raw_data.get_data() - sig_mean.reshape(
  43. sig_mean.shape[0], 1)
  44. return mne.io.RawArray(sig_detrended, mne_raw_data.info)
  45. @classmethod
  46. def detrend_by_linear(cls, mne_raw_data):
  47. """去基漂-去线性
  48. Args:
  49. mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
  50. Returns:
  51. class: mne.io.array.array.RawArray
  52. """
  53. # axis=0 列方向处理
  54. sig_detrended = signal.detrend(mne_raw_data.get_data(), axis=-1)
  55. return mne.io.RawArray(sig_detrended, mne_raw_data.info)
  56. @classmethod
  57. def filter(cls,
  58. mne_raw_data,
  59. l_freq: Optional[int] = 0.1,
  60. h_freq: Optional[int] = 40):
  61. """滤波
  62. l_freq<h_freq:band pass;
  63. l_freq>h_freq:band stop;
  64. l_freq is not None and h_freq is None: high pass;
  65. l_freq is None and h_freq is not None: low pass
  66. Args:
  67. mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
  68. l_freq (Optional[int], optional): low截至频率. Defaults to 0.1.
  69. h_freq (Optional[int], optional): high截至频率. Defaults to 40.
  70. Returns:
  71. class: mne.io.array.array.RawArray
  72. """
  73. return mne_raw_data.copy().filter(l_freq=l_freq, h_freq=h_freq)
  74. @classmethod
  75. def resample(cls, mne_raw_data, new_freq):
  76. """降采样,该函数为了避免混叠,降采样之前会做滤波
  77. Args:
  78. mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
  79. new_freq (float): 新的采样率
  80. Returns:
  81. class: mne.io.array.array.RawArray
  82. """
  83. return mne_raw_data.copy().resample(sfreq=new_freq)
  84. @classmethod
  85. def resample_direct(cls, mne_raw_data, new_freq):
  86. """降采样,直接间隔抽样
  87. Args:
  88. mne_raw_data (mne.io.array.array.RawArray): "mne格式的数据"
  89. new_freq (float): 新的采样率
  90. Returns:
  91. class: mne.io.array.array.RawArray
  92. """
  93. sig = mne_raw_data.get_data()
  94. step = mne_raw_data.info["sfreq"] / new_freq
  95. assert len(sig[0]) % step == 0, \
  96. f"Length if sig ({len(sig)}) can not divided by step({step})"
  97. sig_resampled = sig[:, 0::int(step)]
  98. return mne.io.RawArray(sig_resampled, mne_raw_data.info)
  99. class RealTimeFilter(object):
  100. """ 实时滤波器
  101. 输入设计好的滤波器参数(_ce_a, _ce_b),实时滤波。
  102. y(n) = b(0)*x(n)+b(1)*x(n-1)...-a(1)*y(n-1)-a(2)*y(n-2)...
  103. Attribute:
  104. _order_a: 分母阶数
  105. _order_b: 分子阶数
  106. _buffer_x: 历史输入
  107. _buffer_y: 历史输出
  108. _pos_x: 最新数据点的位置
  109. _pos_y:
  110. _ce_a: 分母系数
  111. _ce_b: 分子系数
  112. """
  113. def __init__(self, ce_a: List[float], ce_b: List[float]):
  114. self._order_a = len(ce_a)
  115. self._order_b = len(ce_b)
  116. self._ce_a = ce_a
  117. self._ce_b = ce_b
  118. # 环形,old<- ->new
  119. self._buffer_x = [0.0] * self._order_b # 存储x(n-N)...x(n-1)
  120. self._buffer_y = [0.0] * self._order_a # 存储y(n-N)...y(n-1)
  121. self._pos_x = 0 # x(n)的存放位置
  122. self._pos_y = 0 # y(n)的存放位置
  123. def filter(self, xn):
  124. self._buffer_x[self._pos_x] = xn
  125. weighted_sum_x = self.cal_weighted_sum_x()
  126. weighted_sum_y = self.cal_weighted_sum_y()
  127. yn = weighted_sum_x - weighted_sum_y
  128. self._buffer_y[self._pos_y] = yn
  129. self._pos_x += 1
  130. if self._pos_x == self._order_b:
  131. self._pos_x = 0
  132. self._pos_y += 1
  133. if self._pos_y == self._order_a:
  134. self._pos_y = 0
  135. return yn
  136. def cal_weighted_sum_x(self):
  137. # b(0)*x(n)+b(1)*x(n-1)...b(N-1)*x(n-N+1)
  138. weighted_sum_x = 0
  139. for ii in range(self._order_b):
  140. pos_x = (self._pos_x - ii + self._order_b) % self._order_b
  141. weighted_sum_x += self._ce_b[ii] * self._buffer_x[pos_x]
  142. return weighted_sum_x
  143. def cal_weighted_sum_y(self):
  144. # a(1)*y(n-1)+a(2)*y(n-2)...+a(N-1)*y(n-N+1)
  145. weighted_sum_y = 0
  146. for ii in range(1, self._order_a):
  147. pos_y = (self._pos_y - ii + self._order_a) % self._order_a
  148. weighted_sum_y += self._ce_a[ii] * self._buffer_y[pos_y]
  149. return weighted_sum_y
  150. @classmethod
  151. def init_eeg(cls, code, fs=1000):
  152. """初始化eeg常用实时滤波器
  153. Args:
  154. code (int): 预处理类型. 0代表0.5Hz高通,1代表60Hz低通.
  155. fs (float, optional): 采样率
  156. Returns:
  157. RealTimFilter: 实时滤波器实例
  158. """
  159. assert code in [0, 1, 2], "Invalid code for eeg RealTimeFilter init!"
  160. if code == 0:
  161. # butter 0.5Hz高通
  162. # aa = [1, -1.982228929792529, 0.982385450614125]
  163. # bb = [0.991153595101663, -1.982307190203327, 0.991153595101663]
  164. bb, aa = signal.butter(2, [2*0.5/fs], "hp")
  165. elif code == 1:
  166. # 60Hz低通
  167. # aa = [1, -0.031426266043351]
  168. # bb = [0.484286866978324, 0.484286866978324]
  169. bb, aa = signal.butter(1, [2*60/fs])
  170. elif code == 2:
  171. # 40Hz低通
  172. # aa = [1, -0.290526856731916]
  173. # bb = [0.354736571634042, 0.354736571634042]
  174. bb, aa = signal.butter(1, [2*40/fs])
  175. return cls(aa, bb)
  176. class RealTimeFilterM(object):
  177. """ 对多个通道同时进行实时滤波器
  178. 输入设计好的滤波器参数(_ce_a, _ce_b),对每个通道进行实时滤波:
  179. y(n) = b(0)*x(n)+b(1)*x(n-1)...-a(1)*y(n-1)-a(2)*y(n-2)...
  180. Attribute:
  181. _order_a: 分母阶数
  182. _order_b: 分子阶数
  183. _channel: 信号的通道数
  184. _buffer_x: 历史输入
  185. _buffer_y: 历史输出
  186. _pos_x: 最新数据点的位置
  187. _pos_y:
  188. _ce_a: 分母系数
  189. _ce_b: 分子系数
  190. """
  191. def __init__(self, ce_a: List[float], ce_b: List[float], channel: int):
  192. self._order_a = len(ce_a)
  193. self._order_b = len(ce_b)
  194. self._channel = channel
  195. self._ce_a = ce_a
  196. self._ce_b = ce_b
  197. # 环形,old<- ->new
  198. self._buffer_x = np.zeros((self._channel, self._order_b),
  199. dtype=np.float64) # 存储x(n-N)...x(n-1)
  200. self._buffer_y = np.zeros((self._channel, self._order_a),
  201. dtype=np.float64) # 存储y(n-N)...y(n-1)
  202. self._pos_x = 0 # x(n)的存放位置
  203. self._pos_y = 0 # y(n)的存放位置
  204. def filter(self, xn: np.ndarray):
  205. self._buffer_x[:, self._pos_x] = xn
  206. weighted_sum_x = self.cal_weighted_sum_x()
  207. weighted_sum_y = self.cal_weighted_sum_y()
  208. yn = weighted_sum_x - weighted_sum_y
  209. self._buffer_y[:, self._pos_y] = yn
  210. self._pos_x += 1
  211. if self._pos_x == self._order_b:
  212. self._pos_x = 0
  213. self._pos_y += 1
  214. if self._pos_y == self._order_a:
  215. self._pos_y = 0
  216. return yn
  217. def cal_weighted_sum_x(self):
  218. # b(0)*x(n)+b(1)*x(n-1)...b(N-1)*x(n-N+1)
  219. weighted_sum_x = np.zeros(self._channel, dtype=np.float64)
  220. for ii in range(self._order_b):
  221. pos_x = (self._pos_x - ii + self._order_b) % self._order_b
  222. weighted_sum_x += self._ce_b[ii] * self._buffer_x[:, pos_x]
  223. return weighted_sum_x
  224. def cal_weighted_sum_y(self):
  225. # a(1)*y(n-1)+a(2)*y(n-2)...+a(N-1)*y(n-N+1)
  226. weighted_sum_y = np.zeros(self._channel, dtype=np.float64)
  227. for ii in range(1, self._order_a):
  228. pos_y = (self._pos_y - ii + self._order_a) % self._order_a
  229. weighted_sum_y += self._ce_a[ii] * self._buffer_y[:, pos_y]
  230. return weighted_sum_y
  231. @classmethod
  232. def init_eeg(cls, code, channel, fs=1000):
  233. """初始化eeg常用实时滤波器
  234. Args:
  235. code (int): 预处理类型. 0代表0.5Hz高通,1代表60Hz低通.
  236. channel(int): 通道数
  237. fs (float, optional): 采样率. Defaults to 1000.
  238. Returns:
  239. RealTimFilterM: 实时滤波器实例
  240. """
  241. assert code in [0, 1, 2], "Invalid code for eeg RealTimeFilter init!"
  242. if code == 0:
  243. # butter 0.5Hz高通
  244. # aa = [1, -1.982228929792529, 0.982385450614125]
  245. # bb = [0.991153595101663, -1.982307190203327, 0.991153595101663]
  246. bb, aa = signal.butter(2, [2*0.5/fs], "hp")
  247. elif code == 1:
  248. # 60Hz低通
  249. # aa = [1, -0.031426266043351]
  250. # bb = [0.484286866978324, 0.484286866978324]
  251. bb, aa = signal.butter(1, [2*60/fs])
  252. elif code == 2:
  253. # 40Hz低通
  254. # aa = [1, -0.290526856731916]
  255. # bb = [0.354736571634042, 0.354736571634042]
  256. bb, aa = signal.butter(1, [2*40/fs])
  257. return cls(aa, bb, channel)