test_neo.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
"""Module tests/core/sig_chain/device/test_neo provides tests for the Neo connector."""
  2. import pytest
  3. import struct
  4. import unittest
  5. from unittest.mock import MagicMock
  6. from unittest.mock import patch
  7. import numpy as np
  8. from device.sig_chain.device.neo import bytes_to_float32
  9. from device.sig_chain.device.neo import NeoConnector
  10. TASK_PER_RUN = 1
  11. def teardown_function():
  12. NeoConnector.clear_instance()
  13. def gen_fake_recv_data(data_count_per_channel, channel_count):
  14. # 假的接收数据
  15. recv_data = np.ones((data_count_per_channel, channel_count),
  16. dtype=np.float32)
  17. for ii in range(channel_count):
  18. recv_data[:, ii] = (ii + 1) * recv_data[:, ii]
  19. return recv_data
  20. # ===================
  21. def test_new_connector_is_disconnected():
  22. connector = NeoConnector()
  23. assert not connector.is_connected()
  24. def test_new_connector_receive_wave_failed():
  25. connector = NeoConnector()
  26. with pytest.raises(Exception):
  27. connector.receive_wave()
  28. def test_after_get_ready_is_connected():
  29. connector = NeoConnector()
  30. mock_socket = MagicMock()
  31. mock_socket.connect.return_value = True
  32. with patch('socket.socket', mock_socket):
  33. success = connector.get_ready()
  34. assert success
  35. assert connector.is_connected()
  36. @unittest.skip('未实现')
  37. def test_after_get_ready_skip_connect_request():
  38. connector = NeoConnector()
  39. connector.get_ready()
  40. success = connector.get_ready()
  41. assert success
  42. def test_after_connected_receive_wave_success():
  43. connector = NeoConnector()
  44. recv_data = gen_fake_recv_data(
  45. connector.sample_params.data_count_per_channel,
  46. connector.sample_params.channel_count)
  47. def side_effect(arg): # 用于确认接收到的参数
  48. assert (arg == recv_data.T).all()
  49. connector._add_a_data_block_to_buffer = MagicMock(side_effect=side_effect)
  50. mock_socket = MagicMock()
  51. mock_socket.connect.return_value = True
  52. mock_socket.recv.return_value = recv_data.tobytes() #b''
  53. connector._sock = mock_socket
  54. success = connector.receive_wave()
  55. assert success
  56. def test_after_stop_is_disconnected():
  57. connector = NeoConnector()
  58. mock_socket = MagicMock()
  59. mock_socket.connect.return_value = True
  60. mock_socket.close = MagicMock()
  61. with patch('socket.socket', mock_socket):
  62. connector.get_ready()
  63. connector.stop()
  64. assert not connector.is_connected()
  65. def test_load_partial_config_success():
  66. connector = NeoConnector()
  67. mock_config = {
  68. 'host': '1.0.0.1'
  69. }
  70. connector.load_config(mock_config)
  71. assert connector._host == mock_config['host']
  72. def test_after_set_saver_buffer_is_set():
  73. connector = NeoConnector()
  74. connector.set_saver()
  75. assert connector.buffer_save is not None
  76. def test_before_set_edf_header_save_data_not_called():
  77. connector = NeoConnector()
  78. connector.set_saver()
  79. mock_save_raw_data = MagicMock()
  80. connector.saver.save_raw_data = mock_save_raw_data
  81. recv_data = gen_fake_recv_data(
  82. connector.sample_params.data_count_per_channel,
  83. connector.sample_params.channel_count)
  84. mock_socket = MagicMock()
  85. mock_socket.connect.return_value = True
  86. mock_socket.recv.return_value = recv_data.tobytes()
  87. connector._sock = mock_socket
  88. connector.receive_wave()
  89. assert not mock_save_raw_data.called
  90. def test_after_receive_wave_observers_are_notified():
  91. connector = NeoConnector()
  92. recv_data = gen_fake_recv_data(
  93. connector.sample_params.data_count_per_channel,
  94. connector.sample_params.channel_count)
  95. mock_socket = MagicMock()
  96. mock_socket.connect.return_value = True
  97. mock_socket.recv.return_value = recv_data.tobytes()
  98. connector._sock = mock_socket
  99. connector._save_data_when_buffer_full = MagicMock()
  100. mock_notify_observers = MagicMock()
  101. connector.notify_observers = mock_notify_observers
  102. connector.receive_wave()
  103. assert mock_notify_observers.called
  104. def test_with_matched_packet_bytes_to_float32_success():
  105. expected = [12.0, 0.0, -12398.1982421875, 34567.98828125]
  106. packet = b''
  107. for value in expected:
  108. packet += struct.pack('f', value)
  109. result = bytes_to_float32(packet, len(packet), 4)
  110. assert expected == result
  111. def test_mismatched_packet_bytes_to_float32_failed():
  112. expected = [12.0, 0.0, -12398.1982421875, 34567.98828125]
  113. packet = b''
  114. for value in expected:
  115. packet += struct.pack('f', value)
  116. packet = packet[:-2]
  117. with pytest.raises(AssertionError):
  118. bytes_to_float32(packet, len(packet), 4)
  119. def test_main():
  120. # pylint: disable=import-outside-toplevel
  121. from schemas.subjects import SubjectCreate
  122. # pylint: enable=import-outside-toplevel
  123. connector = NeoConnector()
  124. connector.set_saver()
  125. subject = SubjectCreate(name='nobody',
  126. id_card='12345',
  127. gender='男',
  128. birthday='1988-01-01',
  129. rehabilitation_parts=['左手'])
  130. connector.saver.set_edf_header(subject, 'filename.bdf', TASK_PER_RUN, '.')
  131. if connector.get_ready():
  132. for _ in range(20):
  133. connector.receive_wave()
  134. connector.stop()