neo.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
"""接收neo软件转发的数据 (Receive data forwarded by the Neo software.)

Typical usage example:

    connector = NeoConnector()
    if connector.get_ready():
        for _ in range(20):
            connector.receive_wave()
    connector.stop()
"""
  9. import logging
  10. import socket
  11. import struct
  12. import numpy as np
  13. from core.sig_chain.device.connector_interface import Connector
  14. from core.sig_chain.device.connector_interface import DataBlockInBuf
  15. from core.sig_chain.device.connector_interface import Device
  16. from core.sig_chain.utils import Observable
  17. from core.sig_chain.utils import Singleton
  18. logger = logging.getLogger(__name__)
  19. def bytes_to_float32(packet: bytes, bytes_of_packet, bytes_per_point=4):
  20. assert bytes_of_packet % bytes_per_point == 0, \
  21. 'Bytes_of_packet % Bytes_per_point != 0'
  22. data_block = []
  23. for ii in range( bytes_of_packet // bytes_per_point):
  24. point_in_bytes = packet[ii * bytes_per_point:(ii + 1) * bytes_per_point]
  25. value = struct.unpack('f', point_in_bytes)[0]
  26. data_block.append(value)
  27. return data_block
  28. class SampleParams:
  29. def __init__(self):
  30. self.channel_count = 9
  31. self.channel_labels = [
  32. 'C3', 'FC3', 'CP5', 'CP1', 'C4', 'FC4', 'CP2', 'CP6', 'Fp1'
  33. ][:self.channel_count]
  34. # montage 中定义的通道类型
  35. self.channel_types = (['eeg'] * 8 +
  36. ['misc'])[:self.channel_count]
  37. self.sample_rate = 1000 # TODO: fixed?
  38. self.data_count_per_channel = int(40 * self.sample_rate / 1000)
  39. self.point_size = 4
  40. # channel: 8 + 1, 一个包传40个点, float: 4 字节; 9 * 40 * 4 = 1440
  41. self.buffer_size = \
  42. self.channel_count * self.data_count_per_channel * self.point_size
  43. self.data_block_size = self.channel_count * self.data_count_per_channel
  44. # 设备将数据量化的物理数值区间
  45. self.physical_max = 200000
  46. self.physical_min = -200000
  47. self.delay_milliseconds = int(self.data_count_per_channel /
  48. self.sample_rate * 1000)
  49. def refresh(self):
  50. self.data_count_per_channel = int(40 * self.sample_rate / 1000)
  51. self.delay_milliseconds = int(self.data_count_per_channel /
  52. self.sample_rate * 1000)
  53. self.buffer_size = \
  54. self.channel_count * self.data_count_per_channel * self.point_size
  55. self.data_block_size = self.channel_count * self.data_count_per_channel
  56. class NeoConnector(Connector, Singleton, Observable):
  57. def __init__(self) -> None:
  58. Observable.__init__(self)
  59. self.device = Device.NEO
  60. self._host = '127.0.0.1'
  61. self._port = 8712
  62. self._addr = (self._host, self._port)
  63. self._sock = None
  64. self._timestamp = 0
  65. self.sample_params = SampleParams()
  66. self._is_connected = False
  67. self.buffer_save = None
  68. self.saver = None
  69. def load_config(self, config_info):
  70. if config_info.get('host'):
  71. self._host = config_info['host']
  72. logger.info('Set host to: %s', self._host)
  73. if config_info.get('port'):
  74. self._port = config_info['port']
  75. logger.info('Set port to: %s', self._port)
  76. if config_info.get('channel_count'):
  77. self.sample_params.channel_count = config_info['channel_count']
  78. logger.info('Set channel count to: %s',
  79. self.sample_params.channel_count)
  80. if config_info.get('channel_labels'):
  81. assert len( config_info['channel_labels']) == \
  82. self.sample_params.channel_count, \
  83. 'Mismatch of channel labels and channel count'
  84. self.sample_params.channel_labels = config_info['channel_labels']
  85. logger.info('Set channel labels to: %s',
  86. self.sample_params.channel_labels)
  87. if config_info.get('sample_rate'):
  88. self.sample_params.sample_rate = config_info['sample_rate']
  89. logger.info('Set sample rate to: %s',
  90. self.sample_params.sample_rate)
  91. # NOTICE: 放在最后执行,以确保更改对相关参数生效
  92. self.sample_params.refresh()
  93. self._addr = (self._host, self._port)
  94. def is_connected(self):
  95. return self._is_connected
  96. def get_ready(self):
  97. self._sock = socket.socket()
  98. try:
  99. self._sock.connect(self._addr)
  100. self._is_connected = True
  101. except ConnectionRefusedError:
  102. return False
  103. return True
  104. def setup_wave_mode(self):
  105. return True
  106. def setup_impedance_mode(self):
  107. return False
  108. def receive_wave(self):
  109. try:
  110. packet = self._sock.recv(self.sample_params.buffer_size)
  111. data_block = np.frombuffer(packet, dtype=np.float32).reshape(
  112. self.sample_params.data_count_per_channel,
  113. self.sample_params.channel_count).T
  114. self._add_a_data_block_to_buffer(data_block)
  115. return True
  116. except ConnectionAbortedError:
  117. return False
  118. except OSError:
  119. return False
  120. except ValueError:
  121. return False
  122. def receive_impedance(self):
  123. raise NotImplementedError
  124. def _add_a_data_block_to_buffer(self, data_block: np.ndarray):
  125. self._timestamp += int(1000 *
  126. self.sample_params.data_count_per_channel /
  127. self.sample_params.sample_rate)
  128. data_block_in_buffer = DataBlockInBuf(data_block,
  129. self._timestamp)
  130. self._save_data_when_buffer_full(data_block_in_buffer)
  131. self.notify_observers(data_block_in_buffer)
  132. # return data_block_2d
  133. def stop(self):
  134. if self._sock:
  135. self._sock.close()
  136. self._is_connected = False
  137. self._timestamp = 0
  138. if self.saver and self.saver.is_ready:
  139. self.saver.close_edf_file()
  140. def notify_observers(self, data_block):
  141. for obj in self._observers:
  142. obj.update(data_block)