"""Receive data forwarded by the Neo software over a TCP socket.

Typical usage example:

    connector = NeoConnector()
    if connector.get_ready():
        for _ in range(20):
            connector.receive_wave()
        connector.stop()
"""
import logging
import socket
import struct

import numpy as np

from core.sig_chain.device.connector_interface import Connector
from core.sig_chain.device.connector_interface import DataBlockInBuf
from core.sig_chain.device.connector_interface import Device
from core.sig_chain.utils import Observable
from core.sig_chain.utils import Singleton

logger = logging.getLogger(__name__)


def bytes_to_float32(packet: bytes, bytes_of_packet, bytes_per_point=4):
    """Decode native-byte-order float32 values from the start of *packet*.

    Args:
        packet: Raw bytes to decode.
        bytes_of_packet: Number of leading bytes of *packet* to decode; must
            be a multiple of *bytes_per_point*.
        bytes_per_point: Size of one value in bytes. Decoding always uses the
            4-byte ``'f'`` struct format, so any other value raises
            ``struct.error`` whenever there is something to decode.

    Returns:
        A list of ``bytes_of_packet // bytes_per_point`` Python floats.

    Raises:
        AssertionError: *bytes_of_packet* is not a multiple of
            *bytes_per_point*.
        struct.error: *packet* is shorter than *bytes_of_packet*, or
            *bytes_per_point* is not 4.
    """
    assert bytes_of_packet % bytes_per_point == 0, \
        'Bytes_of_packet % Bytes_per_point != 0'
    point_count = bytes_of_packet // bytes_per_point
    # One unpack call for the whole run instead of one call per point.
    return list(struct.unpack(f'{point_count}f', packet[:bytes_of_packet]))


class SampleParams:
    """Sampling parameters and derived packet sizes for the Neo stream.

    Every packet carries ``data_count_per_channel`` float32 samples for each
    of ``channel_count`` channels. After changing ``sample_rate`` or
    ``channel_count``, call :meth:`refresh` so the derived sizes follow.
    """

    # Duration of one packet in milliseconds (40 points per packet at 1000 Hz).
    PACKET_DURATION_MS = 40

    def __init__(self):
        self.channel_count = 9
        self.channel_labels = [
            'C3', 'FC3', 'CP5', 'CP1', 'C4', 'FC4', 'CP2', 'CP6', 'Fp1'
        ][:self.channel_count]
        # Channel types as defined in the montage.
        self.channel_types = (['eeg'] * 8 + ['misc'])[:self.channel_count]
        self.sample_rate = 1000  # TODO: fixed?
        # Bytes per sample point (float32).
        self.point_size = 4
        # Physical value range into which the device quantizes data.
        self.physical_max = 200000
        self.physical_min = -200000
        # data_count_per_channel, delay_milliseconds, buffer_size and
        # data_block_size are derived from the values above.
        self.refresh()

    def refresh(self):
        """Recompute the sizes derived from sample_rate and channel_count.

        With the defaults (9 channels = 8 + 1, 40 points per packet, 4-byte
        floats) one packet is 9 * 40 * 4 = 1440 bytes.

        NOTE(review): channel_labels and channel_types are not resized here,
        so changing channel_count alone can leave them out of step with it;
        confirm callers always supply matching labels.
        """
        self.data_count_per_channel = int(
            self.PACKET_DURATION_MS * self.sample_rate / 1000)
        self.delay_milliseconds = int(
            self.data_count_per_channel / self.sample_rate * 1000)
        self.buffer_size = \
            self.channel_count * self.data_count_per_channel * self.point_size
        self.data_block_size = self.channel_count * self.data_count_per_channel


class NeoConnector(Connector, Singleton, Observable):
    """Connector that receives wave data forwarded by the Neo software.

    Data arrives over a TCP connection (default 127.0.0.1:8712) as packets
    of ``sample_params.buffer_size`` bytes of float32 samples.
    """

    def __init__(self) -> None:
        Observable.__init__(self)
        self.device = Device.NEO
        self._host = '127.0.0.1'
        self._port = 8712
        self._addr = (self._host, self._port)
        self._sock = None
        # Milliseconds of data dispatched since the last stop().
        self._timestamp = 0
        self.sample_params = SampleParams()
        self._is_connected = False
        self.buffer_save = None
        self.saver = None

    def load_config(self, config_info):
        """Apply the recognised keys of *config_info* to this connector.

        Recognised keys: ``host``, ``port``, ``channel_count``,
        ``channel_labels`` and ``sample_rate``. Keys with falsy values are
        ignored.

        Raises:
            AssertionError: ``channel_labels`` is given but its length does
                not match the (possibly just updated) channel count.
        """
        if config_info.get('host'):
            self._host = config_info['host']
            logger.info('Set host to: %s', self._host)
        if config_info.get('port'):
            self._port = config_info['port']
            logger.info('Set port to: %s', self._port)
        if config_info.get('channel_count'):
            self.sample_params.channel_count = config_info['channel_count']
            logger.info('Set channel count to: %s',
                        self.sample_params.channel_count)
        if config_info.get('channel_labels'):
            assert len(
                config_info['channel_labels']) == \
                self.sample_params.channel_count, \
                'Mismatch of channel labels and channel count'
            self.sample_params.channel_labels = config_info['channel_labels']
            logger.info('Set channel labels to: %s',
                        self.sample_params.channel_labels)
        if config_info.get('sample_rate'):
            self.sample_params.sample_rate = config_info['sample_rate']
            logger.info('Set sample rate to: %s',
                        self.sample_params.sample_rate)

        # Must run last so the derived sizes reflect every change above.
        self.sample_params.refresh()
        self._addr = (self._host, self._port)

    def is_connected(self):
        """Return True while a connection opened by get_ready() is live."""
        return self._is_connected

    def _close_socket(self):
        """Close the current socket (if any) and mark the connector offline."""
        if self._sock is not None:
            self._sock.close()
            self._sock = None
        self._is_connected = False

    def get_ready(self):
        """Open a TCP connection to the Neo forwarding server.

        Any socket left over from a previous call is closed first, so
        repeated calls do not leak connections.

        Returns:
            True if the connection was established, False if the server
            refused it.

        Raises:
            OSError: any other connection failure (e.g. unresolvable host);
                the new socket is closed before the error propagates.
        """
        self._close_socket()
        sock = socket.socket()
        try:
            sock.connect(self._addr)
        except ConnectionRefusedError:
            sock.close()
            logger.warning('Connection to %s refused', self._addr)
            return False
        except OSError:
            sock.close()
            raise
        self._sock = sock
        self._is_connected = True
        return True

    def setup_wave_mode(self):
        """Wave mode needs no setup for this device; always returns True."""
        return True

    def setup_impedance_mode(self):
        """Impedance mode is not supported by this connector; returns False."""
        return False

    def _recv_exactly(self, size):
        """Read exactly *size* bytes from the socket.

        A single ``recv`` on a TCP stream may return fewer bytes than asked
        for, so this keeps reading until the buffer is full.

        Returns:
            A ``bytearray`` of length *size*, or None if the peer closed the
            connection before *size* bytes arrived.

        Raises:
            OSError: propagated from ``socket.recv_into``.
        """
        buffer = bytearray(size)
        view = memoryview(buffer)
        received = 0
        while received < size:
            count = self._sock.recv_into(view[received:])
            if count == 0:
                return None
            received += count
        return buffer

    def receive_wave(self):
        """Receive one packet of wave data and dispatch it.

        Exactly ``sample_params.buffer_size`` bytes are read, which keeps
        packet boundaries aligned even when TCP splits or coalesces
        packets. Without this, a partial read would shift the channels of
        every later packet.

        Returns:
            True if a full packet was received and dispatched. False if the
            connector is not connected, the peer closed the connection, or an
            OSError/ValueError occurred while receiving or dispatching.
        """
        if self._sock is None:
            return False
        params = self.sample_params
        try:
            packet = self._recv_exactly(params.buffer_size)
            if packet is None:
                logger.info('Connection closed by %s', self._addr)
                self._is_connected = False
                return False
            # Values are laid out point by point (all channels of one point,
            # then the next point); transpose to (channel, point).
            data_block = np.frombuffer(packet, dtype=np.float32).reshape(
                params.data_count_per_channel, params.channel_count).T
            self._add_a_data_block_to_buffer(data_block)
        except (OSError, ValueError) as err:
            # ConnectionAbortedError is a subclass of OSError.
            logger.warning('Failed to receive wave data: %s', err)
            return False
        return True

    def receive_impedance(self):
        """Impedance measurement is not supported by this connector."""
        raise NotImplementedError

    def _add_a_data_block_to_buffer(self, data_block: np.ndarray):
        """Timestamp *data_block*, hand it to the saver hook and observers.

        The timestamp advances by the packet duration in whole milliseconds
        (int-truncated) before the block is wrapped.
        """
        self._timestamp += int(1000 * self.sample_params.data_count_per_channel
                               / self.sample_params.sample_rate)
        data_block_in_buffer = DataBlockInBuf(data_block, self._timestamp)
        # NOTE(review): _save_data_when_buffer_full is not defined in this
        # file; presumably inherited from Connector — confirm.
        self._save_data_when_buffer_full(data_block_in_buffer)
        self.notify_observers(data_block_in_buffer)

    def stop(self):
        """Close the connection, reset the timestamp and close the EDF saver."""
        self._close_socket()
        self._timestamp = 0
        if self.saver and self.saver.is_ready:
            self.saver.close_edf_file()

    def notify_observers(self, data_block):
        """Call ``update(data_block)`` on every registered observer."""
        for obj in self._observers:
            obj.update(data_block)