123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- """接收neo软件转发的数据
- Typical usage example:
- connector = NeoConnector()
- if connector.get_ready():
- for _ in range(20):
- connector.receive_wave()
- connector.stop()
- """
- import logging
- import socket
- import struct
- import numpy as np
- from core.sig_chain.device.connector_interface import Connector
- from core.sig_chain.device.connector_interface import DataBlockInBuf
- from core.sig_chain.device.connector_interface import Device
- from core.sig_chain.utils import Observable
- from core.sig_chain.utils import Singleton
- logger = logging.getLogger(__name__)
def bytes_to_float32(packet: bytes, bytes_of_packet, bytes_per_point=4):
    """Decode the leading ``bytes_of_packet`` bytes of ``packet`` as float32.

    Values are decoded with native byte order (``struct`` format ``'f'``).

    Args:
        packet: Bytes-like buffer holding consecutive float32 values.
        bytes_of_packet: Number of leading bytes of ``packet`` to decode.
        bytes_per_point: Size of one value in bytes; only 4 (float32) is
            supported.

    Returns:
        A list of Python floats, one per decoded point.

    Raises:
        ValueError: If ``bytes_per_point`` is not the float32 size, or
            ``bytes_of_packet`` is not a multiple of it.
        struct.error: If ``packet`` is shorter than ``bytes_of_packet``.
    """
    float_size = struct.calcsize('f')
    if bytes_per_point != float_size:
        # Format 'f' always consumes 4 bytes; any other point size would
        # silently misalign the decoded values.
        raise ValueError(
            f'bytes_per_point must be {float_size} for float32, '
            f'got {bytes_per_point}')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if bytes_of_packet % bytes_per_point != 0:
        raise ValueError('Bytes_of_packet % Bytes_per_point != 0')
    # Clamp at 0 so a non-positive size yields [] like the original loop did.
    point_count = max(bytes_of_packet // bytes_per_point, 0)
    # One C-level unpack of all points instead of a per-point Python loop.
    return list(struct.unpack_from(f'{point_count}f', packet))
class SampleParams:
    """Acquisition parameters for the data stream forwarded by Neo.

    Primary settings (channel layout, sample rate, bytes per point, physical
    range) are set in ``__init__`` and may be overwritten by the caller.
    After changing ``sample_rate``, ``channel_count`` or ``point_size``, call
    :meth:`refresh` so the derived packet-size fields are recomputed.
    """

    # Each packet carries this many milliseconds of samples per channel
    # (40 points per channel at the default 1000 Hz).
    PACKET_DURATION_MS = 40

    def __init__(self):
        self.channel_count = 9
        self.channel_labels = [
            'C3', 'FC3', 'CP5', 'CP1', 'C4', 'FC4', 'CP2', 'CP6', 'Fp1'
        ][:self.channel_count]
        # Channel types as defined in the montage: 8 EEG channels + 1 misc.
        self.channel_types = (['eeg'] * 8 +
                              ['misc'])[:self.channel_count]
        self.sample_rate = 1000  # TODO: fixed?
        # Bytes per sample point (float32).
        self.point_size = 4
        # Physical value range the device quantizes the data into.
        self.physical_max = 200000
        self.physical_min = -200000
        # Derived fields are computed in exactly one place so __init__ and
        # later reconfiguration cannot drift apart.
        self.refresh()

    def refresh(self):
        """Recompute the fields derived from the primary settings.

        Sets ``data_count_per_channel`` (points per channel per packet),
        ``delay_milliseconds`` (packet duration, truncated to whole ms),
        ``buffer_size`` (packet size in bytes) and ``data_block_size``
        (values per packet). With the defaults (9 channels, 1000 Hz, 4-byte
        floats) this gives 40 points, 40 ms, 9 * 40 * 4 = 1440 bytes and
        360 values.
        """
        self.data_count_per_channel = int(
            self.PACKET_DURATION_MS * self.sample_rate / 1000)
        self.delay_milliseconds = int(self.data_count_per_channel /
                                      self.sample_rate * 1000)
        self.buffer_size = \
            self.channel_count * self.data_count_per_channel * self.point_size
        self.data_block_size = self.channel_count * self.data_count_per_channel
class NeoConnector(Connector, Singleton, Observable):
    """Receive data forwarded by the Neo software over a TCP socket.

    Each packet is expected to hold ``sample_params.buffer_size`` bytes of
    float32 values interleaved by sample: ``data_count_per_channel`` rows of
    ``channel_count`` values. A packet is transposed to shape
    (channel_count, data_count_per_channel), wrapped in a ``DataBlockInBuf``
    together with a running timestamp, and passed to every observer.

    Typical usage::

        connector = NeoConnector()
        if connector.get_ready():
            for _ in range(20):
                connector.receive_wave()
            connector.stop()
    """

    def __init__(self) -> None:
        Observable.__init__(self)
        self.device = Device.NEO
        self._host = '127.0.0.1'
        self._port = 8712
        self._addr = (self._host, self._port)
        self._sock = None
        # Running stream time in milliseconds, advanced once per packet.
        self._timestamp = 0
        self.sample_params = SampleParams()
        self._is_connected = False
        self.buffer_save = None
        self.saver = None

    def load_config(self, config_info):
        """Override connection and sampling settings from a config mapping.

        Recognized keys are ``host``, ``port``, ``channel_count``,
        ``channel_labels`` and ``sample_rate``; keys that are missing or have
        a falsy value leave the current setting unchanged. Derived sample
        parameters and the socket address are recomputed at the end; the new
        address is used on the next :meth:`get_ready` call.

        Args:
            config_info: Mapping supporting ``get`` and ``[]`` lookup.

        Raises:
            AssertionError: If ``channel_labels`` is given and its length
                differs from the (possibly just updated) channel count.
        """
        if config_info.get('host'):
            self._host = config_info['host']
            logger.info('Set host to: %s', self._host)
        if config_info.get('port'):
            self._port = config_info['port']
            logger.info('Set port to: %s', self._port)
        if config_info.get('channel_count'):
            self.sample_params.channel_count = config_info['channel_count']
            logger.info('Set channel count to: %s',
                        self.sample_params.channel_count)
            # NOTE(review): this does not resize sample_params.channel_types
            # (nor channel_labels unless they are also supplied) - confirm
            # whether downstream code relies on them matching channel_count.
        if config_info.get('channel_labels'):
            assert len(config_info['channel_labels']) == \
                self.sample_params.channel_count, \
                'Mismatch of channel labels and channel count'
            self.sample_params.channel_labels = config_info['channel_labels']
            logger.info('Set channel labels to: %s',
                        self.sample_params.channel_labels)
        if config_info.get('sample_rate'):
            self.sample_params.sample_rate = config_info['sample_rate']
            logger.info('Set sample rate to: %s',
                        self.sample_params.sample_rate)
        # Run last so the changes above propagate to the derived parameters.
        self.sample_params.refresh()
        self._addr = (self._host, self._port)

    def is_connected(self):
        """Return True while a connection opened by get_ready is active."""
        return self._is_connected

    def get_ready(self):
        """Open a TCP connection to the Neo forwarding server.

        Any socket left from a previous connection is closed first.

        Returns:
            True if the connection was established, False otherwise.
        """
        self._close_socket()
        sock = socket.socket()
        try:
            sock.connect(self._addr)
        except OSError as err:
            # Covers ConnectionRefusedError as well as unknown hosts,
            # unreachable networks and timeouts. Close the socket so a failed
            # attempt does not leak a file descriptor.
            sock.close()
            logger.warning('Failed to connect to %s: %s', self._addr, err)
            return False
        self._sock = sock
        self._is_connected = True
        return True

    def setup_wave_mode(self):
        """Wave mode needs no setup for this device; always succeeds."""
        return True

    def setup_impedance_mode(self):
        """Impedance mode is not supported by this connector."""
        return False

    def receive_wave(self):
        """Receive one packet, decode it and dispatch it to the observers.

        Returns:
            True if a full packet was received and dispatched. False if no
            socket is open, the peer closed the connection, a socket error
            occurred, or the packet could not be reshaped (ValueError).
        """
        if self._sock is None:
            return False
        params = self.sample_params
        try:
            packet = self._recv_exactly(params.buffer_size)
            if packet is None:
                # Peer closed the connection before a full packet arrived.
                return False
            data_block = np.frombuffer(packet, dtype=np.float32).reshape(
                params.data_count_per_channel,
                params.channel_count).T
            self._add_a_data_block_to_buffer(data_block)
        except (OSError, ValueError):
            # ConnectionAbortedError/ConnectionResetError are OSError
            # subclasses; ValueError signals a size/shape mismatch.
            return False
        return True

    def _recv_exactly(self, size):
        """Read exactly ``size`` bytes from the socket.

        TCP is a byte stream, so a single ``recv`` may return only part of a
        packet; keep reading until the whole packet has arrived.

        Returns:
            A bytearray of length ``size``, or None if the peer closed the
            connection before ``size`` bytes were received.
        """
        buf = bytearray(size)
        received = 0
        with memoryview(buf) as view:
            while received < size:
                count = self._sock.recv_into(view[received:])
                if count == 0:
                    return None
                received += count
        return buf

    def receive_impedance(self):
        """Impedance acquisition is not implemented for this device."""
        raise NotImplementedError

    def _add_a_data_block_to_buffer(self, data_block: np.ndarray):
        """Timestamp ``data_block``, hand it to the saver hook and observers.

        The timestamp advances by the packet duration in whole milliseconds.
        ``_save_data_when_buffer_full`` is not defined in this class -
        presumably inherited from ``Connector``; confirm.
        """
        self._timestamp += int(1000 *
                               self.sample_params.data_count_per_channel /
                               self.sample_params.sample_rate)
        data_block_in_buffer = DataBlockInBuf(data_block,
                                              self._timestamp)
        self._save_data_when_buffer_full(data_block_in_buffer)
        self.notify_observers(data_block_in_buffer)

    def _close_socket(self):
        """Close and forget the current socket, if any."""
        if self._sock is not None:
            self._sock.close()
            self._sock = None
        self._is_connected = False

    def stop(self):
        """Close the connection and reset the timestamp.

        Also calls ``saver.close_edf_file()`` when a saver is set and
        reports ``is_ready``.
        """
        self._close_socket()
        self._timestamp = 0
        if self.saver and self.saver.is_ready:
            self.saver.close_edf_file()

    def notify_observers(self, data_block):
        """Call ``update(data_block)`` on every registered observer."""
        # Iterate a snapshot so an observer that (un)registers during its
        # update cannot disturb this loop.
        for obj in list(self._observers):
            obj.update(data_block)
|