123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
"""Module tests/core/sig_chain/device/test_neo provides tests for the Neo connector."""
- import pytest
- import struct
- import unittest
- from unittest.mock import MagicMock
- from unittest.mock import patch
- import numpy as np
- from device.sig_chain.device.neo import bytes_to_float32
- from device.sig_chain.device.neo import NeoConnector
# Passed as the third argument of saver.set_edf_header() in test_main
# (the name suggests "tasks per run" - confirm against the saver API).
TASK_PER_RUN: int = 1
def teardown_function():
    """pytest hook run after every test function in this module.

    Calls NeoConnector.clear_instance() so connector state from one test does
    not carry over into the next (presumably NeoConnector keeps a shared
    instance - confirm against neo.py).
    """
    NeoConnector.clear_instance()
def gen_fake_recv_data(data_count_per_channel, channel_count):
    """Build fake received data for the connector tests.

    Args:
        data_count_per_channel: number of rows (samples per channel).
        channel_count: number of columns (channels).

    Returns:
        A float32 array of shape (data_count_per_channel, channel_count)
        where every entry of column ``k`` equals ``k + 1``.
    """
    # One value per channel: 1, 2, ..., channel_count.
    channel_values = np.arange(1, channel_count + 1, dtype=np.float32)
    # Broadcasting the per-channel row over all samples keeps float32.
    return np.ones((data_count_per_channel, channel_count),
                   dtype=np.float32) * channel_values
- # ===================
def test_new_connector_is_disconnected():
    """A freshly constructed connector must not report being connected."""
    fresh = NeoConnector()
    connected = fresh.is_connected()
    assert not connected
def test_new_connector_receive_wave_failed():
    """receive_wave() on a connector that was never readied must raise."""
    fresh = NeoConnector()
    # NOTE(review): any Exception is accepted here; narrowing this to the
    # specific error NeoConnector raises would make the test stricter.
    with pytest.raises(Exception):
        fresh.receive_wave()
def test_after_get_ready_is_connected():
    """get_ready() succeeds and leaves the connector connected.

    ``socket.socket`` is patched so no real network connection is made.
    ``patch`` replaces the socket *class*. Assuming neo builds its socket by
    calling ``socket.socket(...)``, the object it then uses is the class
    mock's ``return_value``, so ``connect`` must be stubbed on that instance.
    Stubbing it on the class mock itself has no effect.
    """
    connector = NeoConnector()
    mock_socket_cls = MagicMock()
    # Configure the instance the connector receives, not the class mock.
    mock_socket_cls.return_value.connect.return_value = True
    with patch('socket.socket', mock_socket_cls):
        success = connector.get_ready()
    assert success
    assert connector.is_connected()
@pytest.mark.skip(reason='未实现')
def test_after_get_ready_skip_connect_request():
    """A second get_ready() on an already-ready connector should succeed.

    Skipped because the behaviour is not implemented yet (the reason string
    means "not implemented"). pytest's native skip marker is used because
    the rest of this module is written as plain pytest functions.

    NOTE(review): this body does not patch ``socket.socket``, so once the
    skip is removed it would presumably attempt a real network connection.
    Patch it the way test_after_get_ready_is_connected does before enabling.
    """
    connector = NeoConnector()
    connector.get_ready()
    success = connector.get_ready()
    assert success
def test_after_connected_receive_wave_success():
    """receive_wave() on a connected socket succeeds and buffers the block.

    The socket is replaced by a mock whose recv() returns one block of fake
    float32 samples shaped (samples, channels). Every block handed to
    _add_a_data_block_to_buffer is checked *after* receive_wave() returns:
    - the check cannot be swallowed by error handling inside receive_wave();
    - the test fails if the buffer method is never called (the previous
      side_effect-only check passed vacuously in that case);
    - assert_array_equal also checks shape, unlike a broadcasting ``==``.
    """
    connector = NeoConnector()
    recv_data = gen_fake_recv_data(
        connector.sample_params.data_count_per_channel,
        connector.sample_params.channel_count)
    # return_value=None matches what the original side_effect returned.
    mock_add_block = MagicMock(return_value=None)
    connector._add_a_data_block_to_buffer = mock_add_block
    mock_socket = MagicMock()
    mock_socket.connect.return_value = True
    mock_socket.recv.return_value = recv_data.tobytes()
    connector._sock = mock_socket

    success = connector.receive_wave()

    assert success
    assert mock_add_block.called
    # Expected argument is the transpose of the array fed to recv(), i.e.
    # shaped (channels, samples).
    for call in mock_add_block.call_args_list:
        np.testing.assert_array_equal(call[0][0], recv_data.T)
def test_after_stop_is_disconnected():
    """stop() after a successful get_ready() leaves the connector disconnected.

    ``socket.socket`` is patched with a class mock. The socket the connector
    presumably keeps is the mock's ``return_value``, so that instance is
    the one configured. The old class-level ``close`` stub was never reached.
    MagicMock already supplies a callable ``close`` on the instance.
    """
    connector = NeoConnector()
    mock_socket_cls = MagicMock()
    mock_socket_cls.return_value.connect.return_value = True
    with patch('socket.socket', mock_socket_cls):
        connector.get_ready()
    connector.stop()
    assert not connector.is_connected()
def test_load_partial_config_success():
    """load_config() accepts a dict holding only 'host' and applies it."""
    connector = NeoConnector()
    partial_config = dict(host='1.0.0.1')
    connector.load_config(partial_config)
    assert connector._host == partial_config['host']
def test_after_set_saver_buffer_is_set():
    """set_saver() must populate the connector's buffer_save attribute."""
    connector = NeoConnector()
    connector.set_saver()
    buffer_save = connector.buffer_save
    assert buffer_save is not None
def test_before_set_edf_header_save_data_not_called():
    """Without an EDF header, receiving data must not write raw data.

    set_saver() is called but saver.set_edf_header() is not, so the mocked
    saver.save_raw_data must stay untouched during receive_wave().
    """
    connector = NeoConnector()
    connector.set_saver()
    save_raw_data_spy = MagicMock()
    connector.saver.save_raw_data = save_raw_data_spy

    fake_block = gen_fake_recv_data(
        connector.sample_params.data_count_per_channel,
        connector.sample_params.channel_count)
    fake_sock = MagicMock()
    fake_sock.connect.return_value = True
    fake_sock.recv.return_value = fake_block.tobytes()
    connector._sock = fake_sock

    connector.receive_wave()
    save_raw_data_spy.assert_not_called()
def test_after_receive_wave_observers_are_notified():
    """receive_wave() must call notify_observers after reading a block.

    _save_data_when_buffer_full is mocked out so the test does not depend
    on the saver.
    """
    connector = NeoConnector()
    fake_block = gen_fake_recv_data(
        connector.sample_params.data_count_per_channel,
        connector.sample_params.channel_count)
    fake_sock = MagicMock()
    fake_sock.connect.return_value = True
    fake_sock.recv.return_value = fake_block.tobytes()
    connector._sock = fake_sock
    connector._save_data_when_buffer_full = MagicMock()

    notify_spy = MagicMock()
    connector.notify_observers = notify_spy
    connector.receive_wave()
    notify_spy.assert_called()
def test_with_matched_packet_bytes_to_float32_success():
    """Decoding a packet made of whole float32 values returns those values.

    The sample values are exactly representable in float32, so the round
    trip is lossless.
    """
    expected = [12.0, 0.0, -12398.1982421875, 34567.98828125]
    # Native-order single-precision encoding, one value after another.
    packet = b''.join(struct.pack('f', value) for value in expected)
    result = bytes_to_float32(packet, len(packet), 4)
    assert expected == result
def test_mismatched_packet_bytes_to_float32_failed():
    """A packet whose length is not a multiple of 4 bytes is rejected.

    NOTE(review): expecting AssertionError suggests bytes_to_float32
    validates with a bare ``assert``, which ``python -O`` strips. Confirm,
    and consider a real exception type in the production code.
    """
    values = [12.0, 0.0, -12398.1982421875, 34567.98828125]
    full_packet = b''.join(struct.pack('f', v) for v in values)
    truncated = full_packet[:-2]  # 14 bytes: not a whole number of floats
    with pytest.raises(AssertionError):
        bytes_to_float32(truncated, len(truncated), 4)
def test_main():
    """End-to-end smoke run of the connector with a saver attached.

    Sets an EDF header for a dummy subject. If get_ready() succeeds, it
    receives 20 wave blocks and then stops the connector. Nothing is mocked
    here, so this presumably needs a reachable device. When get_ready()
    fails, the test returns without doing anything.

    stop() runs in a ``finally`` so the connection is released even if a
    receive_wave() call raises. Previously an exception leaked it.
    """
    # pylint: disable=import-outside-toplevel
    from schemas.subjects import SubjectCreate
    # pylint: enable=import-outside-toplevel
    connector = NeoConnector()
    connector.set_saver()
    subject = SubjectCreate(name='nobody',
                            id_card='12345',
                            gender='男',  # "male"
                            birthday='1988-01-01',
                            rehabilitation_parts=['左手'])  # "left hand"
    # '.' is presumably the output directory for 'filename.bdf' - confirm.
    connector.saver.set_edf_header(subject, 'filename.bdf', TASK_PER_RUN, '.')
    if not connector.get_ready():
        return
    receive_count = 20
    try:
        for _ in range(receive_count):
            connector.receive_wave()
    finally:
        connector.stop()
|