"""Connect to various EEG devices and preprocess the data they deliver.

Typical usage example:

    receiver = Receiver()
    receiver.select_connector(Device.PONY, buffer_plot_size_seconds=5)
    if receiver.setup_connector():
        receiver.start_receive_wave()
        data_from_buffer = receiver.get_data_from_buffer('plot')
        # The background receive thread uses the lock, so stop with it too.
        receiver.stop_receive(need_lock=True)
"""
import threading
import time
from typing import Optional

from core.sig_chain.device import connector_factory as cf
from core.sig_chain.device.connector_interface import DataMode
from core.sig_chain.device.connector_interface import Device
from core.sig_chain.sig_buffer import ParserNewsetWithTime
from core.sig_chain.sig_buffer import CircularBuffer
from core.sig_chain.utils import Singleton


class Receiver(Singleton):
    """Drive one EEG connector and fan its samples out to circular buffers.

    Inherits from the project's ``Singleton``; ``__init__`` is guarded by
    ``Receiver._init_flag`` so repeated construction does not reset state.
    """

    def __init__(self) -> None:
        # Only initialise once; later ``Receiver()`` calls return early.
        if Receiver._init_flag:
            return
        Receiver._init_flag = True
        self.connector_factory = cf.ConnectorFactory()
        self.connector = None
        self.is_ready = False
        self.trial_num = 0  # TODO: decide whether to keep this attribute
        self.buffer_plot = None
        self.buffer_classify_online = None
        # Serialises connector access between the receive thread started by
        # start_receive_wave() and stop_receive(need_lock=True).
        self.lock = threading.Lock()

    def select_connector(self, device: Device, buffer_plot_size_seconds: float,
                         config_info: Optional[dict] = None):
        """Create the connector for ``device`` and (re)build the buffers.

        Args:
            device: Which device connector the factory should create.
            buffer_plot_size_seconds: Length of the plotting buffer in seconds.
            config_info: Optional config passed to ``connector.load_config``.
        """
        self.connector = self.connector_factory.create_connector(device)
        if config_info:
            self.connector.load_config(config_info)
        # NOTICE: must run after load_config so config changes (sample
        # params etc.) take effect for the buffers.
        self.setup_buffers(buffer_plot_size_seconds)

    def _make_buffer(self, size_seconds, parser):
        """Build one CircularBuffer sized from the connector's sample params."""
        params = self.connector.sample_params
        return CircularBuffer(
            size_seconds,
            params.data_count_per_channel / params.sample_rate,
            params.channel_labels,
            params.channel_types,
            params.sample_rate,
            parser)

    def setup_buffers(self, buffer_plot_size_seconds):
        """Create the plot and online-classification buffers and subscribe
        them as observers of the current connector.

        Raises:
            AssertionError: if a buffer would be shorter than the connector's
                ``delay_milliseconds``.
        """
        BUFFER_CLASSIFY_ONLINE_SIZE_SECONDS = 1
        delay_ms = self.connector.sample_params.delay_milliseconds
        # pylint: disable=line-too-long
        assert buffer_plot_size_seconds * 1000 >= delay_ms, \
            'Buffer size >= delay_milliseconds must be satisfied!'
        assert BUFFER_CLASSIFY_ONLINE_SIZE_SECONDS * 1000 >= delay_ms, \
            'Buffer size >= delay_milliseconds must be satisfied!'
        # pylint: enable=line-too-long
        # Both buffers share one parser instance, as in the original design.
        parser = ParserNewsetWithTime()
        self.buffer_plot = self._make_buffer(buffer_plot_size_seconds, parser)
        self.buffer_classify_online = self._make_buffer(
            BUFFER_CLASSIFY_ONLINE_SIZE_SECONDS, parser)
        self.connector.add_observer(self.buffer_plot)
        self.connector.add_observer(self.buffer_classify_online)

    def setup_connector(self):
        """Clear buffers and ask the connector to get ready.

        Returns:
            The connector's ``get_ready()`` result, also stored in ``is_ready``.
        """
        assert self.connector is not None, 'Select a connector first!'
        self.clear_all_buffer()
        self.is_ready = self.connector.get_ready()
        return self.is_ready

    def clear_all_buffer(self):
        """Empty the contents of every buffer that has been created."""
        if self.buffer_plot:
            self.buffer_plot.content.clear()
        if self.buffer_classify_online:
            self.buffer_classify_online.content.clear()

    def setup_receive_mode(self, mode: DataMode):
        """Switch the connector to wave or impedance mode.

        Buffers are cleared only when entering wave mode.

        Returns:
            Whether the switch succeeded (also stored in ``is_ready``).
        """
        if mode == DataMode.WAVE:
            self.clear_all_buffer()
            success = self.connector.setup_wave_mode()
        else:
            success = self.connector.setup_impedance_mode()
        self.is_ready = success
        return success

    def start_receive_wave(self):
        """Run ``receive_wave(need_lock=True)`` on a new background thread."""
        assert self.is_ready, 'Receiver is not ready!'
        task = threading.Thread(target=self.receive_wave, args=(True,))
        task.start()

    def receive_wave(self, need_lock=False):
        """Poll the connector for wave data until ``is_ready`` becomes False.

        Args:
            need_lock: Whether to hold ``self.lock`` around each poll. Used for
                pony: calling this function directly needs no lock, but when it
                runs on another thread (see start_receive_wave) it does.
        """
        while self.is_ready:
            time.sleep(0.01)
            if need_lock:
                self.lock.acquire()
            try:
                # stop_receive() may have run while we slept; re-check so a
                # stopped connector is never polled again.
                if not self.is_ready:
                    break
                self.connector.receive_wave()
            finally:
                # Always release, even if receive_wave() raises, so that
                # stop_receive(need_lock=True) cannot deadlock.
                if need_lock:
                    self.lock.release()

    def receive_impedance(self):
        """Return the connector's impedance reading."""
        assert self.is_ready, 'Receiver is not ready!'
        return self.connector.receive_impedance()

    def stop_receive(self, need_lock=False):
        """Mark the receiver not ready and stop the connector, if it was ready.

        Args:
            need_lock: Whether to hold ``self.lock`` while stopping. Used for
                pony: if data is not received on a background thread, stopping
                the device needs no lock.
        """
        if not self.is_ready:
            return
        # Clear the flag first so the receive loop exits on its next check.
        self.is_ready = False
        if need_lock:
            self.lock.acquire()
        try:
            self.connector.stop()
        finally:
            if need_lock:
                self.lock.release()

    def get_data_from_buffer(self, buffer_type: str, data_format='mne'):
        """Return the signal held in the requested buffer.

        Args:
            buffer_type: 'plot', 'classify_online' or 'resting_state'.
            data_format: Passed through to the buffer's ``get_sig``.

        Raises:
            RuntimeError: if the receiver is not ready.
            AssertionError: if ``buffer_type`` is not recognised.
        """
        if not self.is_ready:
            raise RuntimeError('Connecter has not been setup correctly !')
        assert buffer_type in ['plot', 'resting_state', 'classify_online'], \
            'Invalid buffer type'
        if buffer_type == 'plot':
            return self.buffer_plot.get_sig(data_format)
        if buffer_type == 'classify_online':
            return self.buffer_classify_online.get_sig(data_format)
        # NOTE(review): 'resting_state' is accepted but has no backing buffer,
        # so it yields None; kept for compatibility — confirm intent.
        return None

    def reset_wave(self):
        """Clear all buffers and ask the connector to restart wave streaming."""
        self.clear_all_buffer()
        self.connector.restart_wave()