"""Receive fake (simulated) data from a local socket server.

Typical usage example:

    connector = FakerConnector()
    if connector.get_ready():
        for _ in range(20):
            connector.receive_wave()
    connector.stop()
"""
import logging
import socket

import numpy as np

from core.sig_chain.device.connector_interface import Connector
from core.sig_chain.device.connector_interface import DataBlockInBuf
from core.sig_chain.device.connector_interface import Device
from core.sig_chain.utils import Observable
from core.sig_chain.utils import Singleton

logger = logging.getLogger(__name__)

# Default channel labels; the first ``channel_count`` entries are used when no
# custom labels are configured.
_DEFAULT_CHANNEL_LABELS = (
    'T6', 'P4', 'Pz', 'M2', 'F8', 'F4', 'Fp1', 'Cz', 'M1', 'F7', 'F3', 'C3',
    'T3', 'A1', 'Oz', 'O1', 'O2', 'Fz', 'C4', 'T4', 'Fp2', 'A2', 'T5', 'P3',
)


class SampleParams:
    """Acquisition parameters plus the packet sizes derived from them.

    Args:
        channel_count: Number of channels in each data block.
        sample_rate: Samples per second for each channel.
        delay_milliseconds: Time span covered by one data block, in ms.
    """

    def __init__(self, channel_count, sample_rate, delay_milliseconds):
        self.channel_count = channel_count
        self.channel_labels = list(_DEFAULT_CHANNEL_LABELS[:channel_count])
        self.sample_rate = sample_rate
        self.delay_milliseconds = delay_milliseconds
        self.point_size = 4        # bytes per sample (read as float32)
        self.timestamp_size = 8    # header bytes that precede the samples
        self.physical_max = 20000
        self.physical_min = -20000
        # Derives channel_types and all packet sizes from the fields above.
        self.refresh()

    def refresh(self):
        """Recompute derived fields after the public parameters change.

        Re-derives the default labels when the current labels no longer
        match ``channel_count``, keeps ``channel_types`` the same length as
        the labels, and recomputes the per-block sample counts and the byte
        size of one packet.
        """
        if len(self.channel_labels) != self.channel_count:
            self.channel_labels = list(
                _DEFAULT_CHANNEL_LABELS[:self.channel_count])
        # Channel types as defined in the montage.
        self.channel_types = ['eeg'] * len(self.channel_labels)
        self.data_count_per_channel = int(
            self.delay_milliseconds * self.sample_rate / 1000)
        self.data_block_size = self.channel_count * self.data_count_per_channel
        self.buffer_size = (self.timestamp_size
                            + self.data_block_size * self.point_size)


class FakerConnector(Connector, Singleton, Observable):
    """Connector that reads simulated data from a TCP server.

    The server address defaults to 127.0.0.1:21112 and can be overridden
    with ``load_config``. After ``get_ready`` sends ``start``, every
    ``receive_wave`` call reads one fixed-size packet
    (``sample_params.buffer_size`` bytes) and publishes it to the observers.
    """

    def __init__(self) -> None:
        Observable.__init__(self)
        self.device = Device.FAKER
        self._host = '127.0.0.1'
        self._port = 21112
        self._addr = (self._host, self._port)
        self._sock = None
        self._timestamp = 0  # ms, advanced by one block duration per packet
        self.sample_params = SampleParams(24, 1000, 250)
        self._is_connected = False
        self.buffer_save = None
        self.saver = None

    def load_config(self, config_info):
        """Apply the truthy entries of ``config_info`` to this connector.

        Recognised keys: ``host``, ``port``, ``channel_count``,
        ``channel_labels``, ``sample_rate``, ``delay_milliseconds``.
        ``channel_labels`` must have exactly ``channel_count`` entries
        (checked after ``channel_count`` from the same dict is applied).
        """
        if config_info.get('host'):
            self._host = config_info['host']
            logger.info('Set host to: %s', self._host)
        if config_info.get('port'):
            self._port = config_info['port']
            logger.info('Set port to: %s', self._port)
        if config_info.get('channel_count'):
            self.sample_params.channel_count = config_info['channel_count']
            logger.info('Set channel count to: %s',
                        self.sample_params.channel_count)
        if config_info.get('channel_labels'):
            assert len(
                config_info['channel_labels']) == \
                self.sample_params.channel_count, \
                'Mismatch of channel labels and channel count'
            self.sample_params.channel_labels = config_info['channel_labels']
            logger.info('Set channel labels to: %s',
                        self.sample_params.channel_labels)
        if config_info.get('sample_rate'):
            self.sample_params.sample_rate = config_info['sample_rate']
            logger.info('Set sample rate to: %s',
                        self.sample_params.sample_rate)
        if config_info.get('delay_milliseconds'):
            self.sample_params.delay_milliseconds = config_info[
                'delay_milliseconds']
            logger.info('Set delay milliseconds to: %s',
                        self.sample_params.delay_milliseconds)

        # NOTICE: keep these last so the changes above take effect on the
        # address and on the derived buffer sizes.
        self._addr = (self._host, self._port)
        self.sample_params.refresh()

    def is_connected(self):
        """Return whether the last ``get_ready`` call succeeded."""
        return self._is_connected

    def get_ready(self):
        """Connect to the fake-data server and ask it to start streaming.

        Any socket left over from a previous call is closed first.

        Returns:
            True when connected and the ``start`` command was sent; False
            when the server refused the connection.

        Raises:
            OSError: Any other socket failure while connecting or sending
                (the new socket is closed before re-raising).
        """
        if self._sock is not None:
            # The old socket would otherwise become unreachable and leak.
            self._sock.close()
        self._is_connected = False
        self._sock = socket.socket()
        try:
            self._sock.connect(self._addr)
            self._sock.sendall(bytes('start', encoding='utf-8'))
        except ConnectionRefusedError:
            logger.warning('Connection refused by %s:%s', *self._addr)
            self._sock.close()
            return False
        except OSError:
            self._sock.close()
            raise
        self._is_connected = True
        return True

    def setup_wave_mode(self):
        """Wave mode needs no setup for the fake device; always True."""
        return True

    def setup_impedance_mode(self):
        """Impedance mode is not supported by the fake device; always False."""
        return False

    def receive_wave(self):
        """Read one packet from the server and publish it to the observers.

        A packet is ``timestamp_size`` header bytes followed by
        ``channel_count * data_count_per_channel`` float32 samples, reshaped
        channel-major to ``(channel_count, data_count_per_channel)``.

        Returns:
            True if a complete packet was received and dispatched; False if
            no socket exists, the peer closed the connection, or an
            ``OSError`` (e.g. ``ConnectionAbortedError``) occurred.
        """
        if self._sock is None:
            return False
        params = self.sample_params
        try:
            packet = self._recv_exact(params.buffer_size)
            if packet is None:
                logger.info('Fake-data server closed the connection')
                return False
            # The header is skipped; timestamps are generated locally in
            # _add_a_data_block_to_buffer.
            samples = np.frombuffer(packet, dtype=np.float32,
                                    offset=params.timestamp_size)
            data_block = samples.reshape(params.channel_count,
                                         params.data_count_per_channel)
            self._add_a_data_block_to_buffer(data_block)
            return True
        except OSError:  # also covers ConnectionAbortedError
            logger.debug('Failed to receive wave data', exc_info=True)
            return False

    def _recv_exact(self, size):
        """Read exactly ``size`` bytes from the socket.

        ``recv`` on a stream socket may return fewer bytes than requested,
        so keep reading until the whole packet has arrived.

        Returns:
            A ``bytearray`` of length ``size``, or None if the peer closed
            the connection before ``size`` bytes arrived.
        """
        buf = bytearray(size)
        received = 0
        with memoryview(buf) as view:
            while received < size:
                count = self._sock.recv_into(view[received:])
                if count == 0:
                    return None
                received += count
        return buf

    def receive_impedance(self):
        """Impedance data is not available from the fake device."""
        raise NotImplementedError

    def _add_a_data_block_to_buffer(self, data_block: np.ndarray):
        """Timestamp ``data_block``, hand it to the saver and observers.

        The timestamp advances by the block duration in whole milliseconds.
        ``_save_data_when_buffer_full`` is not defined in this class;
        presumably it comes from ``Connector`` -- confirm.
        """
        self._timestamp += int(1000 * self.sample_params.data_count_per_channel
                               / self.sample_params.sample_rate)
        data_block_in_buffer = DataBlockInBuf(data_block, self._timestamp)
        self._save_data_when_buffer_full(data_block_in_buffer)
        self.notify_observers(data_block_in_buffer)
        return data_block

    def stop(self):
        """Close the socket, reset state and close the EDF file if open."""
        if self._sock:
            self._sock.close()
        self._is_connected = False
        self._timestamp = 0
        if self.saver and self.saver.is_ready:
            self.saver.close_edf_file()

    def notify_observers(self, data_block):
        """Call ``update(data_block)`` on every registered observer."""
        for obj in self._observers:
            obj.update(data_block)

    def restart_wave(self):
        """Send the ``restart`` command to the server."""
        self._sock.sendall(bytes('restart', encoding='utf-8'))