|
@@ -1,163 +0,0 @@
|
|
|
-"""接收假数据
|
|
|
-
|
|
|
-Typical usage example:
|
|
|
-
|
|
|
- connector = FakerConnector()
|
|
|
- if connector.get_ready():
|
|
|
- for _ in range(20):
|
|
|
- connector.receive_wave()
|
|
|
- connector.stop()
|
|
|
-"""
|
|
|
-import logging
|
|
|
-import numpy as np
|
|
|
-import socket
|
|
|
-
|
|
|
-from device.sig_chain.device.connector_interface import Connector
|
|
|
-from device.sig_chain.device.connector_interface import DataBlockInBuf
|
|
|
-from device.sig_chain.device.connector_interface import Device
|
|
|
-from device.sig_chain.utils import Observable
|
|
|
-from device.sig_chain.utils import Singleton
|
|
|
-
|
|
|
-logger = logging.getLogger(__name__)
|
|
|
-
|
|
|
-
|
|
|
-class SampleParams:
|
|
|
-
|
|
|
- def __init__(self, channel_count, sample_rate, delay_milliseconds):
|
|
|
- self.channel_count = channel_count
|
|
|
- self.channel_labels = [
|
|
|
- 'T6', 'P4', 'Pz', 'M2', 'F8', 'F4', 'Fp1', 'Cz', 'M1', 'F7', 'F3',
|
|
|
- 'C3', 'T3', 'A1', 'Oz', 'O1', 'O2', 'Fz', 'C4', 'T4', 'Fp2', 'A2',
|
|
|
- 'T5', 'P3'
|
|
|
- ][:self.channel_count]
|
|
|
- # montage 中定义的通道类型
|
|
|
- self.channel_types = (['eeg'] * 24)[:self.channel_count]
|
|
|
- self.sample_rate = sample_rate
|
|
|
- self.delay_milliseconds = delay_milliseconds
|
|
|
- self.point_size = 4
|
|
|
- self.timestamp_size = 8
|
|
|
- self.data_count_per_channel = int(self.delay_milliseconds *
|
|
|
- self.sample_rate / 1000)
|
|
|
- self.data_block_size = self.channel_count * self.data_count_per_channel
|
|
|
- self.buffer_size = self.timestamp_size + self.data_block_size * self.point_size
|
|
|
- self.physical_max = 20000
|
|
|
- self.physical_min = -20000
|
|
|
-
|
|
|
- def refresh(self):
|
|
|
- self.data_count_per_channel = int(self.delay_milliseconds *
|
|
|
- self.sample_rate / 1000)
|
|
|
- self.data_block_size = self.channel_count * self.data_count_per_channel
|
|
|
- self.buffer_size = self.timestamp_size + self.data_block_size * self.point_size
|
|
|
-
|
|
|
-
|
|
|
-class FakerConnector(Connector, Singleton, Observable):
|
|
|
-
|
|
|
- def __init__(self) -> None:
|
|
|
- Observable.__init__(self)
|
|
|
- self.device = Device.FAKER
|
|
|
- self._host = '127.0.0.1'
|
|
|
- self._port = 21112
|
|
|
- self._addr = (self._host, self._port)
|
|
|
- self._sock = None
|
|
|
- self._timestamp = 0
|
|
|
-
|
|
|
- self.sample_params = SampleParams(24, 1000, 250)
|
|
|
-
|
|
|
- self._is_connected = False
|
|
|
-
|
|
|
- self.buffer_save = None
|
|
|
- self.saver = None
|
|
|
-
|
|
|
- def load_config(self, config_info):
|
|
|
- if config_info.get('host'):
|
|
|
- self._host = config_info['host']
|
|
|
- logger.info('Set host to: %s', self._host)
|
|
|
- if config_info.get('port'):
|
|
|
- self._port = config_info['port']
|
|
|
- logger.info('Set port to: %s', self._port)
|
|
|
- if config_info.get('channel_count'):
|
|
|
- self.sample_params.channel_count = config_info['channel_count']
|
|
|
- logger.info('Set channel count to: %s',
|
|
|
- self.sample_params.channel_count)
|
|
|
- if config_info.get('channel_labels'):
|
|
|
- assert len( config_info['channel_labels']) == \
|
|
|
- self.sample_params.channel_count, \
|
|
|
- 'Mismatch of channel labels and channel count'
|
|
|
- self.sample_params.channel_labels = config_info['channel_labels']
|
|
|
- logger.info('Set channel labels to: %s',
|
|
|
- self.sample_params.channel_labels)
|
|
|
- if config_info.get('sample_rate'):
|
|
|
- self.sample_params.sample_rate = config_info['sample_rate']
|
|
|
- logger.info('Set sample rate to: %s',
|
|
|
- self.sample_params.sample_rate)
|
|
|
- if config_info.get('delay_milliseconds'):
|
|
|
- self.sample_params.delay_milliseconds = config_info[
|
|
|
- 'delay_milliseconds']
|
|
|
- logger.info('Set delay milliseconds to: %s',
|
|
|
- self.sample_params.delay_milliseconds)
|
|
|
- # NOTICE: 放在最后执行,以确保更改对buffer生效
|
|
|
- self._addr = (self._host, self._port)
|
|
|
- self.sample_params.refresh()
|
|
|
-
|
|
|
- def is_connected(self):
|
|
|
- return self._is_connected
|
|
|
-
|
|
|
- def get_ready(self):
|
|
|
- self._sock = socket.socket()
|
|
|
- try:
|
|
|
- self._sock.connect(self._addr)
|
|
|
- self._is_connected = True
|
|
|
- self._sock.sendall(bytes('start', encoding='utf-8'))
|
|
|
- except ConnectionRefusedError:
|
|
|
- return False
|
|
|
- return True
|
|
|
-
|
|
|
- def setup_wave_mode(self):
|
|
|
- return True
|
|
|
-
|
|
|
- def setup_impedance_mode(self):
|
|
|
- return False
|
|
|
-
|
|
|
- def receive_wave(self):
|
|
|
- try:
|
|
|
- packet = self._sock.recv(self.sample_params.buffer_size)
|
|
|
- # timestamp = struct.unpack_from("d", packet[:2])
|
|
|
- packet_parse = np.frombuffer(packet, dtype=np.float32)
|
|
|
- data_block = packet_parse[2:].reshape(
|
|
|
- self.sample_params.channel_count,
|
|
|
- self.sample_params.data_count_per_channel)
|
|
|
- self._add_a_data_block_to_buffer(data_block)
|
|
|
- return True
|
|
|
- except ConnectionAbortedError:
|
|
|
- return False
|
|
|
- except IOError:
|
|
|
- return False
|
|
|
-
|
|
|
- def receive_impedance(self):
|
|
|
- raise NotImplementedError
|
|
|
-
|
|
|
- def _add_a_data_block_to_buffer(self, data_block: np.ndarray):
|
|
|
- self._timestamp += int(1000 *
|
|
|
- self.sample_params.data_count_per_channel /
|
|
|
- self.sample_params.sample_rate)
|
|
|
- data_block_in_buffer = DataBlockInBuf(data_block, self._timestamp)
|
|
|
- self._save_data_when_buffer_full(data_block_in_buffer)
|
|
|
- self.notify_observers(data_block_in_buffer)
|
|
|
-
|
|
|
- return data_block
|
|
|
-
|
|
|
- def stop(self):
|
|
|
- if self._sock:
|
|
|
- self._sock.close()
|
|
|
- self._is_connected = False
|
|
|
- self._timestamp = 0
|
|
|
-
|
|
|
- if self.saver and self.saver.is_ready:
|
|
|
- self.saver.close_edf_file()
|
|
|
-
|
|
|
- def notify_observers(self, data_block):
|
|
|
- for obj in self._observers:
|
|
|
- obj.update(data_block)
|
|
|
-
|
|
|
- def restart_wave(self):
|
|
|
- self._sock.sendall(bytes('restart', encoding='utf-8'))
|