sig_buffer.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. """环形buffer,用来缓存一定长度的数据"""
  2. import collections
  3. import itertools
  4. import math
  5. import mne
  6. import numpy as np
  7. from core.sig_chain.device.montage_base_model import MontageBase
  8. from core.sig_chain.utils import Observer
  9. class ParserNewset():
  10. """策略方法类
  11. """
  12. def parser_newset(self,
  13. package_num,
  14. content,
  15. mbm_created,
  16. dataformat="mne"):
  17. """策略类方法接口
  18. Args:
  19. package_num (int): 包的数量
  20. content (dqueue): 数据队列
  21. mbm_created (class): mne 的info
  22. dataformat (str, optional): 数据类型,默认为mne格式,目前dataformat为任意其它
  23. 值都会返回nparray格式. Defaults to "mne".
  24. """
  25. pass
  26. class ParserNewsetWithTime(ParserNewset):
  27. """类策略方法:解析有时间戳的数据
  28. Args:
  29. ParserNewset (class): 父类
  30. """
  31. def parser_newset(self,
  32. package_num,
  33. content,
  34. mbm_created,
  35. dataformat="mne"):
  36. """解析数据有时间戳的数据
  37. Args:
  38. package_num (int): 包数量
  39. content (class): dqueue
  40. mbm_created (class): mne 的info
  41. dataformat (str, optional): 数据类型,默认为mne格式,目前dataformat为任意其它
  42. 值都会返回nparray格式. Defaults to "mne".
  43. Returns:
  44. dict: 返回数据和状态和时间戳
  45. """
  46. status = "unknown"
  47. if content and len(content) >= package_num:
  48. data_list = []
  49. time_list = []
  50. for con in list(content):
  51. data_list.append(con.data)
  52. time_list.append(con.timestamp)
  53. signals = np.concatenate(data_list, axis=1)
  54. status = "ok"
  55. raw_data = mne.io.RawArray(
  56. signals, mbm_created.info) if dataformat == "mne" else signals
  57. return {"status": status, "data": raw_data, "timestamp": time_list}
  58. else:
  59. return {"status": "warn", "data": None, "timestamp": None}
  60. class PaserNewsetWithoutTime(ParserNewset):
  61. """类策略方法:解析没有时间戳的数据
  62. Args:
  63. ParserNewset (class): 父类
  64. """
  65. def parser_newset(self,
  66. package_num,
  67. content,
  68. mbm_created,
  69. dataformat="mne"):
  70. """解析数据没有时间戳的数据
  71. Args:
  72. package_num (int): 包数量
  73. content (class): dqueue
  74. mbm_created (class): mne 的info
  75. dataformat (str, optional): 数据类型,默认为mne格式,目前dataformat为任意其它
  76. 值都会返回nparray格式. Defaults to "mne".
  77. Returns:
  78. dict: 返回数据和状态
  79. """
  80. status = "unknown"
  81. if content and len(content) >= package_num:
  82. signals = np.concatenate(tuple(
  83. list(itertools.islice(content, 0, None))),
  84. axis=1)
  85. status = "ok"
  86. raw_data = mne.io.RawArray(
  87. signals, mbm_created.info) if dataformat == "mne" else signals
  88. return {"status": status, "data": raw_data}
  89. else:
  90. return {"status": "warn", "data": None}
  91. class CircularBuffer(Observer):
  92. """环形buffer类"""
  93. def __init__(self, data_len, package_len, chan_labels, chan_types, fs,
  94. parser):
  95. """初始化一个环形buffer
  96. Args:
  97. data_len (float): 数据长度,以秒为单位,例如要缓存20s的数据,data_len值为20
  98. package_len (float): 包长度,以秒为单位,例如设备每100ms发送一个包,则package_len值为0.1
  99. chan_labels (List[str]): 导联标签
  100. chan_types (List[str]): 可以是任意str,一般写为"eeg"即可,注意要对每个chan_labels都定义
  101. fs (float): 采样率
  102. parser (class): 数据解析,来自于ParserNewset的类策略
  103. """
  104. self.data_len = data_len
  105. self.package_len = package_len
  106. self.package_num = math.ceil(self.data_len / self.package_len)
  107. self.chan_labels = chan_labels
  108. self.fs = fs
  109. self.content = collections.deque(maxlen=self.package_num)
  110. self.mbm_created = MontageBase(chan_labels, chan_types, fs)
  111. self._shape_status = {"ok": "ok", "warn": "warn"}
  112. self.parser = parser
  113. def update(self, newset):
  114. """更新buffer中的数据
  115. Args:
  116. newset (np array or other): 设备定时发来的数据,一般为chan_count*samples的二维矩阵
  117. """
  118. # if newset.any():
  119. # if newset:
  120. self.content.append(newset)
  121. # else:
  122. # pass
  123. def get_sig(self, dataformat="mne", clear=True):
  124. """获得数据并转为mne格式
  125. Args:
  126. dataformat (str): 数据类型,默认为mne格式,目前dataformat为任意其它值都会返回nparray格式
  127. clear (bool): 是否清空buffer的标志,默认为清空
  128. Returns:
  129. dict: 一个字典,"status"表示得到的数据维度是否正确,"ok"表示正确,"warn"表示维度和预期不相符;
  130. "data"默认为mne格式,也可以为nparray,根据策略方法不同,需要时也会有时间戳的输出
  131. """
  132. ret = self.parser.parser_newset(self.package_num, self.content,
  133. self.mbm_created, dataformat)
  134. if ret["status"] == "ok" and clear:
  135. self.content.clear()
  136. return ret