sig_receive.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. """连接多种脑电设备,并对接收的数据进行预处理
  2. Typical usage example:
  3. receiver = Receiver()
  4. receiver.select_connector(Device.PONY)
  5. if receiver.setup_connector():
  6. receiver.start_receive_wave()
  7. data_from_buffer = receiver.get_data_from_buffer('plot')
  8. receiver.stop_receive()
  9. """
  10. import threading
  11. import time
  12. from core.sig_chain.device import connector_factory as cf
  13. from core.sig_chain.device.connector_interface import DataMode
  14. from core.sig_chain.device.connector_interface import Device
  15. from core.sig_chain.sig_buffer import ParserNewsetWithTime
  16. from core.sig_chain.sig_buffer import CircularBuffer
  17. from core.sig_chain.utils import Singleton
  18. class Receiver(Singleton):
  19. def __init__(self) -> None:
  20. if Receiver._init_flag:
  21. return
  22. Receiver._init_flag = True
  23. self.connector_factory = cf.ConnectorFactory()
  24. self.connector = None
  25. self.is_ready = False
  26. self.trial_num = 0 # TODO: 是否保留
  27. self.buffer_plot = None
  28. self.buffer_classify_online = None
  29. self.lock = threading.Lock()
  30. def select_connector(self,
  31. device: Device,
  32. buffer_plot_size_seconds: float,
  33. config_info: dict = None):
  34. self.connector = self.connector_factory.create_connector(device)
  35. if config_info:
  36. self.connector.load_config(config_info)
  37. # NOTICE: 放在load_config最后执行,以确保更改对buffer等生效
  38. self.setup_buffers(buffer_plot_size_seconds)
  39. def setup_buffers(self, buffer_plot_size_seconds):
  40. BUFFER_CLASSIFY_ONLINE_SIZE_SECONDS = 1
  41. # pylint: disable=line-too-long
  42. assert buffer_plot_size_seconds * 1000 >= self.connector.sample_params.delay_milliseconds, \
  43. 'Buffer size >= delay_milliseconds must be satisfied!'
  44. assert BUFFER_CLASSIFY_ONLINE_SIZE_SECONDS * 1000 >= self.connector.sample_params.delay_milliseconds, \
  45. 'Buffer size >= delay_milliseconds must be satisfied!'
  46. # pylint: enable=line-too-long
  47. parser = ParserNewsetWithTime()
  48. self.buffer_plot = CircularBuffer(
  49. buffer_plot_size_seconds,
  50. self.connector.sample_params.data_count_per_channel /
  51. self.connector.sample_params.sample_rate,
  52. self.connector.sample_params.channel_labels,
  53. self.connector.sample_params.channel_types,
  54. self.connector.sample_params.sample_rate, parser)
  55. self.buffer_classify_online = CircularBuffer(
  56. BUFFER_CLASSIFY_ONLINE_SIZE_SECONDS,
  57. self.connector.sample_params.data_count_per_channel /
  58. self.connector.sample_params.sample_rate,
  59. self.connector.sample_params.channel_labels,
  60. self.connector.sample_params.channel_types,
  61. self.connector.sample_params.sample_rate, parser)
  62. self.connector.add_observer(self.buffer_plot)
  63. self.connector.add_observer(self.buffer_classify_online)
  64. def setup_connector(self):
  65. assert self.connector is not None, 'Select a connector first!'
  66. self.clear_all_buffer()
  67. self.is_ready = self.connector.get_ready()
  68. return self.is_ready
  69. def clear_all_buffer(self):
  70. if self.buffer_plot:
  71. self.buffer_plot.content.clear()
  72. if self.buffer_classify_online:
  73. self.buffer_classify_online.content.clear()
  74. def setup_receive_mode(self, mode: DataMode):
  75. success = False
  76. if mode == DataMode.WAVE:
  77. self.clear_all_buffer()
  78. success = self.connector.setup_wave_mode()
  79. else:
  80. success = self.connector.setup_impedance_mode()
  81. self.is_ready = success
  82. return success
  83. def start_receive_wave(self):
  84. assert self.is_ready, 'Receiver is not ready!'
  85. task = threading.Thread(target=self.receive_wave, args=(True,))
  86. task.start()
  87. def receive_wave(self, need_lock=False):
  88. """
  89. Args:
  90. need_lock:是否需要加锁,用于pony,因为直接调用这个函数是不需要加锁的;
  91. 而这个函数在另一个线程中执行时是需要加锁的
  92. Returns:
  93. """
  94. while self.is_ready:
  95. time.sleep(0.01)
  96. if need_lock:
  97. self.lock.acquire()
  98. self.connector.receive_wave()
  99. if need_lock:
  100. self.lock.release()
  101. def receive_impedance(self):
  102. assert self.is_ready, 'Receiver is not ready!'
  103. return self.connector.receive_impedance()
  104. def stop_receive(self, need_lock=False):
  105. """
  106. Args:
  107. need_lock:是否需要加锁,用于pony,因为如果不使用多线程接收数据,
  108. 那么停止设备时就不需要加锁
  109. Returns:
  110. """
  111. if self.is_ready:
  112. self.is_ready = False
  113. if need_lock:
  114. self.lock.acquire()
  115. self.connector.stop()
  116. if need_lock:
  117. self.lock.release()
  118. def get_data_from_buffer(self, buffer_type: str, data_format='mne'):
  119. if not self.is_ready:
  120. raise RuntimeError('Connecter has not been setup correctly !')
  121. assert buffer_type in ['plot', 'resting_state', 'classify_online'], \
  122. 'Invalid buffer type'
  123. if buffer_type == 'plot':
  124. return self.buffer_plot.get_sig(data_format)
  125. elif buffer_type == 'classify_online':
  126. return self.buffer_classify_online.get_sig(data_format)
  127. def reset_wave(self):
  128. self.clear_all_buffer()
  129. self.connector.restart_wave()