faker.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
"""Receive fake (simulated) data from a faker data server over a TCP socket.

Typical usage example:

    connector = FakerConnector()
    if connector.get_ready():
        for _ in range(20):
            connector.receive_wave()
    connector.stop()
"""
  9. import logging
  10. import numpy as np
  11. import socket
  12. from core.sig_chain.device.connector_interface import Connector
  13. from core.sig_chain.device.connector_interface import DataBlockInBuf
  14. from core.sig_chain.device.connector_interface import Device
  15. from core.sig_chain.utils import Observable
  16. from core.sig_chain.utils import Singleton
  17. logger = logging.getLogger(__name__)
  18. class SampleParams:
  19. def __init__(self, channel_count, sample_rate, delay_milliseconds):
  20. self.channel_count = channel_count
  21. self.channel_labels = [
  22. 'T6', 'P4', 'Pz', 'M2', 'F8', 'F4', 'Fp1', 'Cz', 'M1', 'F7', 'F3',
  23. 'C3', 'T3', 'A1', 'Oz', 'O1', 'O2', 'Fz', 'C4', 'T4', 'Fp2', 'A2',
  24. 'T5', 'P3'
  25. ][:self.channel_count]
  26. # montage 中定义的通道类型
  27. self.channel_types = (['eeg'] * 24)[:self.channel_count]
  28. self.sample_rate = sample_rate
  29. self.delay_milliseconds = delay_milliseconds
  30. self.point_size = 4
  31. self.timestamp_size = 8
  32. self.data_count_per_channel = int(self.delay_milliseconds *
  33. self.sample_rate / 1000)
  34. self.data_block_size = self.channel_count * self.data_count_per_channel
  35. self.buffer_size = self.timestamp_size + self.data_block_size * self.point_size
  36. self.physical_max = 20000
  37. self.physical_min = -20000
  38. def refresh(self):
  39. self.data_count_per_channel = int(self.delay_milliseconds *
  40. self.sample_rate / 1000)
  41. self.data_block_size = self.channel_count * self.data_count_per_channel
  42. self.buffer_size = self.timestamp_size + self.data_block_size * self.point_size
  43. class FakerConnector(Connector, Singleton, Observable):
  44. def __init__(self) -> None:
  45. Observable.__init__(self)
  46. self.device = Device.FAKER
  47. self._host = '127.0.0.1'
  48. self._port = 21112
  49. self._addr = (self._host, self._port)
  50. self._sock = None
  51. self._timestamp = 0
  52. self.sample_params = SampleParams(24, 1000, 250)
  53. self._is_connected = False
  54. self.buffer_save = None
  55. self.saver = None
  56. def load_config(self, config_info):
  57. if config_info.get('host'):
  58. self._host = config_info['host']
  59. logger.info('Set host to: %s', self._host)
  60. if config_info.get('port'):
  61. self._port = config_info['port']
  62. logger.info('Set port to: %s', self._port)
  63. if config_info.get('channel_count'):
  64. self.sample_params.channel_count = config_info['channel_count']
  65. logger.info('Set channel count to: %s',
  66. self.sample_params.channel_count)
  67. if config_info.get('channel_labels'):
  68. assert len( config_info['channel_labels']) == \
  69. self.sample_params.channel_count, \
  70. 'Mismatch of channel labels and channel count'
  71. self.sample_params.channel_labels = config_info['channel_labels']
  72. logger.info('Set channel labels to: %s',
  73. self.sample_params.channel_labels)
  74. if config_info.get('sample_rate'):
  75. self.sample_params.sample_rate = config_info['sample_rate']
  76. logger.info('Set sample rate to: %s',
  77. self.sample_params.sample_rate)
  78. if config_info.get('delay_milliseconds'):
  79. self.sample_params.delay_milliseconds = config_info[
  80. 'delay_milliseconds']
  81. logger.info('Set delay milliseconds to: %s',
  82. self.sample_params.delay_milliseconds)
  83. # NOTICE: 放在最后执行,以确保更改对buffer生效
  84. self._addr = (self._host, self._port)
  85. self.sample_params.refresh()
  86. def is_connected(self):
  87. return self._is_connected
  88. def get_ready(self):
  89. self._sock = socket.socket()
  90. try:
  91. self._sock.connect(self._addr)
  92. self._is_connected = True
  93. self._sock.sendall(bytes('start', encoding='utf-8'))
  94. except ConnectionRefusedError:
  95. return False
  96. return True
  97. def setup_wave_mode(self):
  98. return True
  99. def setup_impedance_mode(self):
  100. return False
  101. def receive_wave(self):
  102. try:
  103. packet = self._sock.recv(self.sample_params.buffer_size)
  104. # timestamp = struct.unpack_from("d", packet[:2])
  105. packet_parse = np.frombuffer(packet, dtype=np.float32)
  106. data_block = packet_parse[2:].reshape(
  107. self.sample_params.channel_count,
  108. self.sample_params.data_count_per_channel)
  109. self._add_a_data_block_to_buffer(data_block)
  110. return True
  111. except ConnectionAbortedError:
  112. return False
  113. except IOError:
  114. return False
  115. def receive_impedance(self):
  116. raise NotImplementedError
  117. def _add_a_data_block_to_buffer(self, data_block: np.ndarray):
  118. self._timestamp += int(1000 *
  119. self.sample_params.data_count_per_channel /
  120. self.sample_params.sample_rate)
  121. data_block_in_buffer = DataBlockInBuf(data_block, self._timestamp)
  122. self._save_data_when_buffer_full(data_block_in_buffer)
  123. self.notify_observers(data_block_in_buffer)
  124. return data_block
  125. def stop(self):
  126. if self._sock:
  127. self._sock.close()
  128. self._is_connected = False
  129. self._timestamp = 0
  130. if self.saver and self.saver.is_ready:
  131. self.saver.close_edf_file()
  132. def notify_observers(self, data_block):
  133. for obj in self._observers:
  134. obj.update(data_block)
  135. def restart_wave(self):
  136. self._sock.sendall(bytes('restart', encoding='utf-8'))