Browse Source

Refactor: 删除无用代码

dk 1 year ago
parent
commit
bce2fafa55

+ 4 - 1
.gitignore

@@ -12,7 +12,10 @@ __pycache__/
 # projects
 backend/db/data/
 backend/static/video/
-backend/static/images/
+backend/static/images/*
+!backend/static/images/.gitkeep
+backend/static/models/*
+!backend/static/models/.gitkeep
 backend/logs/
 node_modules/
 

+ 0 - 0
backend/db/models/subject.py → backend/db/subject.py


+ 0 - 0
backend/db/models/test.py → backend/db/test.py


+ 0 - 0
backend/db/models/train.py → backend/db/train.py


+ 2 - 8
backend/device/sig_chain/sig_save.py

@@ -59,7 +59,7 @@ class SigSave():
         """ 用于设置EDF头部信息
 
         Args:
-            subject (class): 受试数据库实体
+            subject (str): subject name
             path (str): 存储数据路径
         """
         channel_info = []
@@ -80,13 +80,7 @@ class SigSave():
             }
             channel_info.append(ch_dict)
         self.edf_w.setSignalHeaders(channel_info)
-        self.edf_w.setPatientName(subject.name)
-        if subject.gender == "男":
-            self.edf_w.setGender(1)
-        elif subject.gender == "女":
-            self.edf_w.setGender(0)
-        self.edf_w.setBirthdate(subject.birthday)
-        self.edf_w.setPatientCode(subject.id_card)
+        self.edf_w.setPatientName(subject)
         self.edf_w.setRecordingAdditional(str(task_per_run))
         self.is_ready = True
         self.is_first = True

+ 1 - 1
backend/main.py

@@ -3,7 +3,7 @@ from datetime import datetime
 
 import streamlit as st
 
-from db.models import subject
+from db import subject
 from components.remove_style import hide_footer
 
 

+ 3 - 3
backend/pages/2_train.py

@@ -4,8 +4,8 @@ import os
 
 import streamlit as st
 
-from db.models import subject
-from db.models import train
+from db import subject
+from db import train
 from components.remove_style import hide_footer
 import page_utils
 
@@ -18,7 +18,7 @@ def _create_train(conn, subjects):
         trial_num = st.number_input("训练次数", value=10, step=5)
         owner_name = st.selectbox("用户", subjects.name.to_list())
         virtual_feedback_rate = st.number_input("假反馈比例", min_value=0., max_value=1., value=0., step=0.2)
-        model_path = page_utils.file_selector(os.path.join(f'./model/{owner_name}'))
+        model_path = page_utils.file_selector(os.path.join(f'./static/models/{owner_name}'))
         submitted = st.form_submit_button("开始训练")
         if submitted:
             start_time = datetime.strptime(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "%Y-%m-%d %H:%M:%S")

+ 2 - 1
backend/pages/3_test.py

@@ -2,7 +2,8 @@
 import streamlit as st
 import os
 from datetime import datetime
-from db.models import test, subject
+from db import subject
+from db import test
 from components.remove_style import hide_footer
 
 import page_utils

+ 0 - 74
backend/schemas/hand_peripheral.py

@@ -1,74 +0,0 @@
-"""睿手相关参数模型"""
-from enum import Enum
-
-from pydantic import BaseModel
-from pydantic import Field
-
-
-FINGERMODEL_IDS = {
-    'rest': 0,
-    'cylinder': 1,
-    'ball': 2,
-    'flex': 3,
-    'double': 4,
-    'treble': 5
-}
-
-
-class ChannelName(int, Enum):
-    CHANNEL_A = 0x01
-    CHANNEL_B = 0x02
-
-
-class DraftChannel(int, Enum):
-    SINGLE = 0x01
-    DOUBLE = 0x02
-
-
-class IsElectric(int, Enum):
-    WITH_ELECTRIC = 0x00
-    WITHOUT_ELECTRIC = 0x01
-
-
-class SetCurrent(BaseModel):
-    """set current pydantic model"""
-    channel: ChannelName = Field(
-        ...,
-        description=
-        "set peripheral hand current channel (channelA: 0x01, channelB: 0x02)")
-    value: int = Field(...,
-                       le=255,
-                       ge=0,
-                       description="set peripheral hand current value")
-
-
-class ControlMotion(BaseModel):
-    """control motion pydantic model"""
-    hand_select: str = Field()
-    thumb: int = Field(..., le=100, ge=0)
-    index_finger: int = Field(..., le=100, ge=0)
-    middle_finger: int = Field(..., le=100, ge=0)
-    ring_finger: int = Field(..., le=100, ge=0)
-    little_finger: int = Field(..., le=100, ge=0)
-    duration: int = Field(..., le=20, ge=5)
-
-
-class DraftingAction(BaseModel):
-    """drafting action pydantic model"""
-    hand_select: str = Field(
-        ...,
-        description="select control hand (double: 0x01, left: 0x02, right:0x03)"
-    )
-    is_electric: IsElectric = Field(
-        ..., description="model (with electric: 0x00, without electric: 0x01)")
-    draft_channel: DraftChannel = Field(
-        ..., description="select channel (a channel: 0x01, double: 0x02)")
-    a_channel_value: int = Field(...,
-                                 le=255,
-                                 ge=0,
-                                 description="set A channel hand current value")
-    b_channel_value: int = Field(...,
-                                 le=255,
-                                 ge=0,
-                                 description="set b channel hand current value")
-    duration: int = Field(..., le=20, ge=5)

+ 0 - 86
backend/schemas/subjects.py

@@ -1,86 +0,0 @@
-"""Module schemas/subjects verifies table data type"""
-from datetime import date
-from datetime import datetime
-from typing import List, Literal
-from typing import Optional
-from typing import Union
-
-from pydantic import BaseModel
-from pydantic import Field
-from pydantic import validator
-
-from schemas.trains import ShowTrain
-from settings.config import settings
-
-language = settings.config["lang"]
-message_dict = settings.get_message()
-message = message_dict[language]
-
-
-def get_timestamp() -> datetime:
-    return datetime.strptime(datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
-                             "%Y-%m-%d %H:%M:%S")
-
-
-class SubjectBase(BaseModel):
-    """Subject Base Pydantic Model"""
-    name: str
-    id_card: Union[str, None]
-    gender: Literal["男", "女"]
-    birthday: date
-    rehabilitation_parts: list
-    remarks: str = ""
-
-    @validator("birthday")
-    def validate_birthday_date(cls, value):
-        value = datetime.strptime(str(value), "%Y-%m-%d")
-        if value > datetime.now() or value < datetime.strptime(
-                "1880-01-01", "%Y-%m-%d"):
-            raise ValueError(message["form_error_gender"])
-        return value
-
-    @validator("rehabilitation_parts")
-    def validate_rehabilitation_parts_date(cls, value):
-        if len(value) == 0 or len(value) > 4:
-            raise ValueError(message["rehab_parts_length"])
-        for part in value:
-            if part not in ["左手", "右手", "左腿", "右腿"]:
-                raise ValueError(message["rehab_parts_value_error"])
-        return value
-
-
-class SubjectUpdate(SubjectBase):
-    pass
-
-
-class SubjectCreate(SubjectBase):
-    create_time: Optional[datetime] = Field(default_factory=get_timestamp)
-
-
-class ShowSubject(BaseModel):
-    """展示患者信息"""
-    id: str
-    name: str
-    id_card: str
-    age: int
-    gender: str
-    rehabilitation_parts: str
-    create_time: datetime
-
-    class Config():
-        orm_mode = True
-
-
-class ShowSubjectDetails(ShowSubject):
-    """展示患者详情"""
-    trains: List[ShowTrain]
-
-
-class TodayStats(BaseModel):
-
-    today_num_format: str
-
-
-class SubjectIds(BaseModel):
-
-    ids: List[str] = Field(...)

+ 0 - 66
backend/schemas/trains.py

@@ -1,66 +0,0 @@
-"""Module schemas/trains verifies table data type"""
-from datetime import datetime
-from typing import Literal, Optional
-from typing import Union
-
-from pydantic import BaseModel, Field
-
-from schemas.hand_peripheral import ControlMotion
-
-
-def get_timestamp() -> datetime:
-    return datetime.strptime(datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
-                             "%Y-%m-%d %H:%M:%S")
-
-
-class TrainBase(BaseModel):
-    position: Optional[str] = None
-    rank: Optional[str] = None
-    trial_num: Optional[str] = None
-
-
-class TrainUpdate(TrainBase):
-    # position: str
-    # rank: str
-    # trial_num: int
-    start_time: Optional[datetime] = Field(default_factory=get_timestamp)
-    end_time: Optional[datetime] = Field(default_factory=get_timestamp)
-
-
-class TrainCreate(TrainBase):
-    # position: str
-    # rank: str
-    # trial_num: int
-    start_time: Optional[datetime] = Field(default_factory=get_timestamp)
-    end_time: Optional[datetime] = Field(default_factory=get_timestamp)
-    device_param: Union[ControlMotion, None] = None
-    owner_id: str
-
-
-class ShowTrain(TrainBase):
-    position: str
-    rank: str
-    trial_num: int
-    start_time: datetime
-    end_time: datetime
-    grade: str = None
-    consume_time: int = None
-    accuracy: float = None
-    is_train: bool = False
-    medical_certificate: str = ""
-
-    class Config():
-        orm_mode = True
-
-class ShowTrainWithVideo(ShowTrain):
-    video_path: list
-
-
-class TrainResult(BaseModel):
-    grade: Literal["优秀", "良好", "尚可"] = Field()
-    accuracy: float = Field()
-    consume_time: float = Field()
-
-
-class TrainMedicalCertificate(BaseModel):
-    medical_certificate: str

+ 8 - 0
backend/settings/config.py

@@ -28,6 +28,14 @@ class Settings:
             'STIM'
         ]
     }
+    FINGERMODEL_IDS = {
+        'rest': 0,
+        'cylinder': 1,
+        'ball': 2,
+        'flex': 3,
+        'double': 4,
+        'treble': 5
+    }
     PROJECT_VERSION: str = '0.0.1'
     DATA_PATH = './data'
 

+ 0 - 214
backend/static/config/config.json

@@ -1,214 +0,0 @@
-{
-    "lang": "zh",
-    "hospital": "XXX医院",
-    "URL": {
-        "base": "http://localhost:8000",
-        "ws_base": "ws://localhost:8000",
-        "static": "/static",
-        "camera_route": "/api/v1/motion/camera",
-        "camera_set_output": "/api/v1/motion/camera/set-output",
-        "close_camera_route": "/api/v1/motion/close-camera",
-        "eeg_data_read": "/api/v1/eeg/data",
-        "eeg_data_buffer": "/api/v1/eeg/data-buffer",
-        "eeg_device_connect": "/api/v1/eeg/eeg-device-connect",
-        "eeg_device_close": "/api/v1/eeg/eeg-model-close",
-        "eeg_restart_fake_data": "/api/v1/eeg/restart-fake-data",
-        "eeg_clf_reset": "/api/v1/eeg/eeg-clf-reset",
-        "eeg_pipeline_reset": "/api/v1/eeg/eeg-pipeline-reset",
-        "eeg_train_configs": "/api/v1/eeg/train-configs",
-        "initial_rest_state_run": "/api/v1/eeg/initial-rest-state-run",
-        "mi_state_run": "/api/v1/eeg/mi-state-run",
-        "rest_state_run": "/api/v1/eeg/rest-state-run",
-        "mi_test_run": "/api/v1/eeg/mi-test-run",
-        "eeg_edf_set_header": "/api/v1/eeg/eeg-edf-set-header",
-        "eeg_save_data": "/api/v1/eeg/eeg-save-data",
-        "eeg_result_data": "/api/v1/trains/{train_id}/result",
-        "api_train_medical_certificate": "/api/v1/trains/{train_id}/medical-certificate",
-        "impedance_model_connect": "/api/v1/eeg/impedance-model-connect",
-        "impedance_model_close": "/api/v1/eeg/impedance-model-close",
-        "impedance_data": "/api/v1/eeg/impedance-data",
-        "set_train_finish_flag": "/api/v1/eeg/set-train-finish-flag",
-        "get_train_finish_flag": "/api/v1/eeg/get-train-finish-flag",
-        "delete_train": "/api/v1/trains/{train_id}",
-        "raw_bdf_data_close": "/api/v1/eeg/eeg-edf-close",
-        "eeg_edf_mark": "/api/v1/eeg/eeg-edf-mark",
-        "get_today_stats": "/api/v1/subjects/today-stats",
-        "startup_peripheral": "/api/v1/trains/{train_id}/startup-peripheral",
-        "web_subjects": "/subjects",
-        "web_subjects_update": "/subjects/{subject_id}",
-        "web_subjects_details": "/subjects/{subject_id}/details",
-        "api_subjects_delete": "/api/v1/subjects/{subject_id}",
-        "web_trains_start": "/trains/{train_id}/start",
-        "web_trains_test": "/trains/{train_id}/test",
-        "web_trains_details": "/trains/{train_id}/details",
-        "api_subjects_autocomplete": "api/v1/subjects/autocomplete",
-        "api_peripheral_get_serial_ports": "/api/v1/peripheral/serial-ports",
-        "api_peripheral_hand_init": "/api/v1/peripheral/hand/init",
-        "api_peripheral_hand_start": "/api/v1/peripheral/hand/start",
-        "api_peripheral_hand_stop": "/api/v1/peripheral/hand/stop",
-        "api_peripheral_hand_status": "/api/v1/peripheral/hand/status",
-        "api_peripheral_hand_close": "/api/v1/peripheral/hand/close",
-        "api_mi_img_erds": "/api/v1/mi/img/erds",
-        "api_mi_img_csp": "/api/v1/mi/img/csp",
-        "api_mi_img_wpli": "/api/v1/mi/img/wpli",
-        "api_mi_img_psd": "/api/v1/mi/img/psd"
-    },
-    "resource": {
-        "camera_placeholder": "/images/camera_placeholder.png"
-    },
-    "camera": {
-        "id": 0,
-        "task": "record"
-    },
-    "test_parameter": {
-        "rest_decrease_time": 0,
-        "eeg_psd_class": 1,
-        "fake_data": false,
-        "verify": false,
-        "device": "neo"
-    },
-    "faker_eeg_config": {
-        "host": "127.0.0.1",
-        "port": 21112,
-        "channel_count": 24,
-        "sample_rate": 1000,
-        "delay_milliseconds": 100,
-        "buffer_plot_size_seconds": 0.1,
-        "channel_labels": [
-            "T6",
-            "P4",
-            "Pz",
-            "M2",
-            "F8",
-            "F4",
-            "Fp1",
-            "Cz",
-            "M1",
-            "F7",
-            "F3",
-            "C3",
-            "T3",
-            "A1",
-            "Oz",
-            "O1",
-            "O2",
-            "Fz",
-            "C4",
-            "T4",
-            "Fp2",
-            "A2",
-            "T5",
-            "P3"
-        ],
-        "sig_types": [
-            "saw_tooth",
-            "square",
-            "square",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin",
-            "sin"
-        ],
-        "source": "bdf_data/sample.bdf",
-        "signal_generator_config": {
-            "frequency": 2,
-            "wave_height": 100,
-            "saw_tooth_peak_num": 40,
-            "noise": true,
-            "baseline_shift": 0
-        }
-    },
-    "pony_eeg_config": {
-        "device_address": "192.168.1.88",
-        "triggerbox_address": "10.0.0.63",
-        "gain": 12,
-        "channel_count": 24,
-        "sample_rate": 1000,
-        "delay_milliseconds": 100,
-        "buffer_plot_size_seconds": 0.1,
-        "channel_labels": [
-            "T6",
-            "P4",
-            "Pz",
-            "M2",
-            "F8",
-            "F4",
-            "Fp1",
-            "Cz",
-            "M1",
-            "F7",
-            "F3",
-            "C3",
-            "T3",
-            "A1",
-            "Oz",
-            "O1",
-            "O2",
-            "Fz",
-            "C4",
-            "T4",
-            "Fp2",
-            "A2",
-            "T5",
-            "P3"
-        ]
-    },
-    "neo_eeg_config": {
-        "host": "127.0.0.1",
-        "port": 8712,
-        "channel_count": 9,
-        "sample_rate": 1000,
-        "delay_milliseconds": 40,
-        "buffer_plot_size_seconds": 0.04,
-        "channel_labels": [
-            "C3",
-            "FC3",
-            "CP5",
-            "CP1",
-            "C4",
-            "FC4",
-            "CP2",
-            "CP6",
-            "Fp1"
-        ]
-    },
-    "hand_peripheral_parameter": {
-        "hand_host": "192.168.1.1",
-        "hand_port": 21111,
-        "hand_version": [1, 0],
-        "hand_heart": 0.5
-    },
-    "frontend_plot":{
-        "sample_rate": 100,
-        "show_channel": [
-            "C3",
-            "FC3",
-            "CP5",
-            "CP1",
-            "C4",
-            "FC4",
-            "CP2",
-            "CP6",
-            "Fp1"
-        ],
-        "max_time": 10,
-        "update_duration": 50
-    }
-}

+ 0 - 38
backend/static/config/message_zh.json

@@ -1,38 +0,0 @@
-{
-    "update_success": "更新成功",
-    "delete_success": "删除成功",
-    "update_failed": "更新失败",
-    "delete_failed": "删除失败",
-    "create_success": "创建成功",
-    "create_failed": "创建失败",
-    "invalid_age_input": "无效年龄输入",
-    "invalid_gender_input": "无效性别输入",
-    "form_error_name": "请填写姓名",
-    "form_error_id_card": "请填写标识号码",
-    "form_error_gender": "请填写有效性别:男或女",
-    "form_error_age": "请填写有效年龄",
-    "form_error_birthday": "请填写出生年月",
-    "form_error_plan": "请填写康复计划",
-    "subject_id_missing": "用户记录没有找到",
-    "train_id_missing": "训练记录没有找到",
-    "open_camera_failed": "摄像头打开失败",
-    "close_camera_success": "摄像头关闭成功",
-    "name_require": "^请输入姓名",
-    "name_length": "^请输入姓名,中/英/符号,长度30以内",
-    "id_card_require": "^请输入病历号",
-    "id_exclusion": "^该病历号已存在,请重新输入",
-    "gender_require": "^请输入性别",
-    "rehab_parts_length": "^请选择至少一个康复部位",
-    "rehab_parts_value_error": "^请输入有效的康复部位",
-    "birth_require": "^请选择出生日期",
-    "birth_range": "^请确认您的年龄在5~100岁",
-    "ruishou_connect_failed": "睿手连接失败",
-    "ruishou_start_success": "睿手启动成功",
-    "ruishou_no_effect_part": "不是有效部位",
-    "hand_peripheral_not_init": "机械手未初始化",
-    "pneumatic_finger_init_success": "气动手初始化成功",
-    "pneumatic_finger_init_failed": "气动手初始化失败,请检查设备是否已启动并进入镜像模式",
-    "pneumatic_finger_operate_success": "气动手操作成功",
-    "pneumatic_finger_operate_failed": "气动手操作失败,请查看设备是否已启动并进入镜像模式",
-    "pneumatic_finger_close_success": "气动手关闭成功"
-}

+ 0 - 0
backend/model/.gitkeep → backend/static/images/.gitkeep


+ 0 - 0
backend/tests/utils/__init__.py → backend/static/models/.gitkeep


BIN
backend/tests/data/5_3_right_hand.bdf


File diff suppressed because it is too large
+ 0 - 0
backend/tests/data/eeg_raw_data.bdf


BIN
backend/tests/data/neo_eeg_raw_data.bdf


BIN
backend/tests/data/normal_side.mp4


+ 0 - 68
backend/tests/device/peripheral/hand/test_fubo_pneumatic_finger.py

@@ -1,68 +0,0 @@
-'''
-@Author  :   liujunshen
-@File    :   test_fubo_pneumatic_finger.py
-@Time    :   2023/04/04 17:13:01
-富伯气动手测试用例,需要: 1.连接富伯气动手 2.开机 3.进入镜像模式 4.启动 5.获取串口名称并修改
-'''
-
-import time
-
-import pytest
-
-from device.peripheral.hand.fubo_pneumatic_finger import FuboPneumaticFingerClient
-from device.peripheral.hand.fubo_pneumatic_finger import get_serial_ports
-
-PORT = "COM4"
-init_params = {"port": PORT}
-
-
-@pytest.mark.fubo_pneumatic_finger
-def test_get_ports_from_computer_success():
-    ports = get_serial_ports()
-    assert len(ports) > 0
-
-
-@pytest.mark.fubo_pneumatic_finger
-def test_client_init_success():
-    client = FuboPneumaticFingerClient(init_params)
-    ret = client.init()
-    assert ret["is_connected"]
-    client.close()
-
-
-@pytest.mark.fubo_pneumatic_finger
-def test_client_close_success():
-    client = FuboPneumaticFingerClient(init_params)
-    client.init()
-    ret = client.close()
-    assert not ret["is_connected"]
-
-
-@pytest.mark.fubo_pneumatic_finger
-def test_start_flex_success():
-    client = FuboPneumaticFingerClient(init_params)
-    client.init()
-    receive = client.flex()
-    assert len(receive) > 0
-    time.sleep(3)
-    client.close()
-
-
-@pytest.mark.fubo_pneumatic_finger
-def test_start_extend_success():
-    client = FuboPneumaticFingerClient(init_params)
-    client.init()
-    receive = client.extend()
-    assert len(receive) > 0
-    time.sleep(3)
-    client.close()
-
-
-@pytest.mark.fubo_pneumatic_finger
-def test_start_operate_success():
-    client = FuboPneumaticFingerClient(init_params)
-    client.init()
-    receive = client.start()
-    assert len(receive) > 0
-    time.sleep(15)
-    client.close()

+ 0 - 274
backend/tests/device/peripheral/hand/test_ruishou.py

@@ -1,274 +0,0 @@
-"""
-@Author  :   liujunshen
-@File    :   test_ruishou.py
-@Time    :   2023/04/04 13:57:03
-"""
-
-from collections import namedtuple
-import time
-
-import pytest
-
-from device.peripheral.hand.ruishou import Constants
-from device.peripheral.hand.ruishou import Protocol
-from device.peripheral.hand.ruishou import RuishouClient
-from device.peripheral.hand.ruishou import RuishouConnector
-
-buffer_time = 0.3
-ParamStruct = namedtuple(
-    "ParamStruct",
-    "hand_select thumb index_finger middle_finger ring_finger little_finger duration"
-)
-
-
-# ============= 测试解析模块 ================
-def test_protocol_get_pack_success():
-    protocol = Protocol()
-    ret = protocol.get_pck("finish_action")
-    assert isinstance(ret, bytearray)
-
-
-def test_protocol_get_pack_with_fail_cmd_return_none():
-    protocol = Protocol()
-    ret = protocol.get_pck("error_cmd")
-    assert ret is None
-
-
-def test_protocol_get_motion_control_pack_success():
-    params = {
-        Constants.SendPckLocation.MOTION_CONTROL_HAND: 2,
-        Constants.SendPckLocation.MOTION_CONTROL_THUMB_BENDING: 15,
-        Constants.SendPckLocation.MOTION_CONTROL_INDEX_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_MIDDLE_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_RING_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_LITTLE_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_DURATION: 10
-    }
-    protocol = Protocol()
-    ret = protocol.get_pck("motion_control", params)
-    assert isinstance(ret, bytearray)
-
-
-def test_protocol_get_motion_control_pack_with_lack_update_dict_return_none():
-    params = {
-        Constants.SendPckLocation.MOTION_CONTROL_HAND: 2,
-        Constants.SendPckLocation.MOTION_CONTROL_THUMB_BENDING: 15,
-        Constants.SendPckLocation.MOTION_CONTROL_RING_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_LITTLE_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_DURATION: 10,
-    }
-    protocol = Protocol()
-    ret = protocol.get_pck("motion_control", params)
-    assert ret is None
-
-
-def test_protocol_get_motion_control_pack_with_error_update_dict_return_none():
-    params = {
-        11: buffer_time,
-        Constants.SendPckLocation.MOTION_CONTROL_THUMB_BENDING: 15,
-        Constants.SendPckLocation.MOTION_CONTROL_RING_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_LITTLE_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_DURATION: 10,
-    }
-    protocol = Protocol()
-    ret = protocol.get_pck("motion_control", params)
-    assert ret is None
-
-
-def test_protocol_unpack_one_cmd_bytes_success():
-    protocol = Protocol()
-    b = b"\xae\xaf\x05\x01\x00\x00\xff\xff\x01\xff"
-    ret = protocol.unpack_bytes(b)
-    assert isinstance(ret, list)
-    assert len(ret) == 1
-
-
-def test_protocol_unpack_two_cmd_bytes_success():
-    protocol = Protocol()
-    b = b"\xae\xaf\x05\x01\x00\x00\xff\xff\x01\xff\xae\xaf\x05\x02\xff\xff\xff\xff\x03\xfe"
-    ret = protocol.unpack_bytes(b)
-    assert isinstance(ret, list)
-    assert len(ret) == 2
-
-
-def test_protocol_unpack_bytes_with_error_bytes_return_empty_list():
-    protocol = Protocol()
-    b = b"\x05\x01\x00\x00\xff\xff\x01\xff"
-    ret = protocol.unpack_bytes(b)
-    assert isinstance(ret, list)
-    assert len(ret) == 0
-
-
-def test_protocol_parse_list_success():
-    protocol = Protocol()
-    b = b"\xae\xaf\x05\x01\x00\x00\xff\xff"
-    unpack_data = protocol.unpack_bytes(b)
-    parsed_data = protocol.parse_bytes(unpack_data[0])
-    assert isinstance(parsed_data, dict)
-
-
-# =======以下需要启动设备或模拟测试软件=============
-
-
-@pytest.mark.ruishou
-def test_connector_connect_and_close_success():
-    connector = RuishouConnector()
-    connector.start_client()
-    time.sleep(buffer_time)
-    connector.close_client()
-    time.sleep(buffer_time)
-
-
-@pytest.mark.ruishou
-def test_connector_sync_send_control_motion_data_success():
-    connector = RuishouConnector()
-    connector.start_client()
-    params = {
-        Constants.SendPckLocation.MOTION_CONTROL_HAND: 2,
-        Constants.SendPckLocation.MOTION_CONTROL_THUMB_BENDING: 15,
-        Constants.SendPckLocation.MOTION_CONTROL_INDEX_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_MIDDLE_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_RING_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_LITTLE_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_DURATION: 5
-    }
-    res = connector.sync_send_data("motion_control", params)
-    assert isinstance(res, dict)
-    time.sleep(5)
-    connector.close_client()
-    time.sleep(buffer_time)
-
-
-@pytest.mark.ruishou
-def test_connector_sync_send_error_data_return_none():
-    connector = RuishouConnector()
-    connector.start_client()
-    params = {}
-    res = connector.sync_send_data("error_data", params)
-    assert res is None
-    time.sleep(buffer_time)
-    connector.close_client()
-    time.sleep(buffer_time)
-
-
-@pytest.mark.ruishou
-def test_connector_stop_operate_success():
-    connector = RuishouConnector()
-    connector.start_client()
-    params = {
-        Constants.SendPckLocation.MOTION_CONTROL_HAND: 1,
-        Constants.SendPckLocation.MOTION_CONTROL_THUMB_BENDING: 15,
-        Constants.SendPckLocation.MOTION_CONTROL_INDEX_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_MIDDLE_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_RING_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_LITTLE_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_DURATION: 5
-    }
-    connector.sync_send_data("motion_control", params)
-    time.sleep(3)
-    connector.sync_send_data("finish_action")
-    connector.close_client()
-    time.sleep(buffer_time)
-
-
-@pytest.mark.ruishou
-def test_connector_sync_send_control_motion_error_params_fail():
-    connector = RuishouConnector()
-    connector.start_client()
-    params = {
-        Constants.SendPckLocation.MOTION_CONTROL_HAND: 2,
-        Constants.SendPckLocation.MOTION_CONTROL_THUMB_BENDING: 15,
-        Constants.SendPckLocation.MOTION_CONTROL_RING_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_LITTLE_FINGER_BENDING: 10,
-        Constants.SendPckLocation.MOTION_CONTROL_DURATION: 5
-    }
-    res = connector.sync_send_data("motion_control", params)
-    assert res is None
-    time.sleep(buffer_time)
-    connector.close_client()
-    time.sleep(buffer_time)
-
-
-# ========== 测试睿手客户端(对业务) ============
-
-
-@pytest.mark.ruishou
-def test_client_init_success():
-    client = RuishouClient()
-    ret = client.init()
-    assert ret["is_connected"]
-    time.sleep(buffer_time)
-    client.close()
-    time.sleep(buffer_time)
-
-
-@pytest.mark.ruishou
-def test_client_get_status_success():
-    client = RuishouClient()
-    client.init()
-    status = client.status()
-    assert status["is_connected"]
-    time.sleep(buffer_time)
-    client.close()
-    time.sleep(buffer_time)
-
-
-@pytest.mark.ruishou
-def test_client_get_status_with_not_init_return_not_connected():
-    client = RuishouClient()
-    status = client.status()
-    assert not status["is_connected"]
-
-
-@pytest.mark.ruishou
-def test_client_get_status_with_close_client_return_not_connected():
-    client = RuishouClient()
-    client.init()
-    time.sleep(buffer_time)
-    client.close()
-    status = client.status()
-    assert not status["is_connected"]
-    time.sleep(buffer_time)
-
-
-@pytest.mark.ruishou
-def test_client_start_operate_success():
-    client = RuishouClient()
-    client.init()
-    params = ParamStruct("左手", 10, 10, 10, 10, 10, 6)
-    res = client._control_motion(params)
-    assert isinstance(res, dict)
-    time.sleep(5)
-    client.close()
-    time.sleep(buffer_time)
-
-
-@pytest.mark.ruishou
-def test_client_reconnect_start_operate_success():
-    client = RuishouClient()
-    client.init()
-    params = ParamStruct("左手", 10, 10, 10, 10, 10, 6)
-    time.sleep(buffer_time)
-    client.close()
-    res = client._control_motion(params)
-    assert isinstance(res, dict)
-    time.sleep(5)
-    client.close()
-    time.sleep(buffer_time)
-
-
-@pytest.mark.ruishou
-def test_client_close_success():
-    client = RuishouClient()
-    client.init()
-    time.sleep(buffer_time)
-    client.close()
-    assert not client.connector.is_connected
-    time.sleep(buffer_time)
-
-
-@pytest.mark.ruishou
-def test_client_close_with_not_init_success():
-    client = RuishouClient()
-    client.close()
-    assert not client.connector.is_connected

+ 0 - 186
backend/tests/device/sig_chain/device/test_neo.py

@@ -1,186 +0,0 @@
-"""Module tests/core/sig_chain/device/test_neo provide test for neo connector"""
-import pytest
-import struct
-import unittest
-from unittest.mock import MagicMock
-from unittest.mock import patch
-
-import numpy as np
-
-from device.sig_chain.device.neo import bytes_to_float32
-from device.sig_chain.device.neo import NeoConnector
-
-
-TASK_PER_RUN = 1
-
-
-def teardown_function():
-    NeoConnector.clear_instance()
-
-
-def gen_fake_recv_data(data_count_per_channel, channel_count):
-    # 假的接收数据
-    recv_data = np.ones((data_count_per_channel, channel_count),
-                        dtype=np.float32)
-    for ii in range(channel_count):
-        recv_data[:, ii] = (ii + 1) * recv_data[:, ii]
-    return recv_data
-
-
-# ===================
-
-
-def test_new_connector_is_disconnected():
-    connector = NeoConnector()
-    assert not connector.is_connected()
-
-
-def test_new_connector_receive_wave_failed():
-    connector = NeoConnector()
-    with pytest.raises(Exception):
-        connector.receive_wave()
-
-
-def test_after_get_ready_is_connected():
-    connector = NeoConnector()
-
-    mock_socket = MagicMock()
-    mock_socket.connect.return_value = True
-    with patch('socket.socket', mock_socket):
-        success = connector.get_ready()
-        assert success
-    assert connector.is_connected()
-
-
-@unittest.skip('未实现')
-def test_after_get_ready_skip_connect_request():
-    connector = NeoConnector()
-    connector.get_ready()
-    success = connector.get_ready()
-    assert success
-
-
-def test_after_connected_receive_wave_success():
-    connector = NeoConnector()
-
-    recv_data = gen_fake_recv_data(
-        connector.sample_params.data_count_per_channel,
-        connector.sample_params.channel_count)
-
-    def side_effect(arg): # 用于确认接收到的参数
-        assert (arg == recv_data.T).all()
-    connector._add_a_data_block_to_buffer = MagicMock(side_effect=side_effect)
-
-    mock_socket = MagicMock()
-    mock_socket.connect.return_value = True
-    mock_socket.recv.return_value = recv_data.tobytes() #b''
-    connector._sock = mock_socket
-
-    success = connector.receive_wave()
-    assert success
-
-
-def test_after_stop_is_disconnected():
-    connector = NeoConnector()
-    mock_socket = MagicMock()
-    mock_socket.connect.return_value = True
-    mock_socket.close = MagicMock()
-    with patch('socket.socket', mock_socket):
-        connector.get_ready()
-        connector.stop()
-    assert not connector.is_connected()
-
-
-def test_load_partial_config_success():
-    connector = NeoConnector()
-    mock_config = {
-        'host': '1.0.0.1'
-    }
-    connector.load_config(mock_config)
-    assert connector._host == mock_config['host']
-
-
-def test_after_set_saver_buffer_is_set():
-    connector = NeoConnector()
-    connector.set_saver()
-
-    assert connector.buffer_save is not None
-
-
-def test_before_set_edf_header_save_data_not_called():
-    connector = NeoConnector()
-    connector.set_saver()
-
-    mock_save_raw_data = MagicMock()
-    connector.saver.save_raw_data = mock_save_raw_data
-
-    recv_data = gen_fake_recv_data(
-        connector.sample_params.data_count_per_channel,
-        connector.sample_params.channel_count)
-
-    mock_socket = MagicMock()
-    mock_socket.connect.return_value = True
-    mock_socket.recv.return_value = recv_data.tobytes()
-    connector._sock = mock_socket
-    connector.receive_wave()
-    assert not mock_save_raw_data.called
-
-
-def test_after_receive_wave_observers_are_notified():
-    connector = NeoConnector()
-
-    recv_data = gen_fake_recv_data(
-        connector.sample_params.data_count_per_channel,
-        connector.sample_params.channel_count)
-
-    mock_socket = MagicMock()
-    mock_socket.connect.return_value = True
-    mock_socket.recv.return_value = recv_data.tobytes()
-    connector._sock = mock_socket
-    connector._save_data_when_buffer_full = MagicMock()
-
-    mock_notify_observers = MagicMock()
-    connector.notify_observers = mock_notify_observers
-
-    connector.receive_wave()
-    assert mock_notify_observers.called
-
-
-def test_with_matched_packet_bytes_to_float32_success():
-    expected = [12.0, 0.0, -12398.1982421875, 34567.98828125]
-    packet = b''
-    for value in expected:
-        packet += struct.pack('f', value)
-
-    result = bytes_to_float32(packet, len(packet), 4)
-
-    assert expected == result
-
-
-def test_mismatched_packet_bytes_to_float32_failed():
-    expected = [12.0, 0.0, -12398.1982421875, 34567.98828125]
-    packet = b''
-    for value in expected:
-        packet += struct.pack('f', value)
-    packet = packet[:-2]
-
-    with pytest.raises(AssertionError):
-        bytes_to_float32(packet, len(packet), 4)
-
-
-def test_main():
-    # pylint: disable=import-outside-toplevel
-    from schemas.subjects import SubjectCreate
-    # pylint: enable=import-outside-toplevel
-    connector = NeoConnector()
-    connector.set_saver()
-    subject = SubjectCreate(name='nobody',
-                            id_card='12345',
-                            gender='男',
-                            birthday='1988-01-01',
-                            rehabilitation_parts=['左手'])
-    connector.saver.set_edf_header(subject, 'filename.bdf', TASK_PER_RUN, '.')
-    if connector.get_ready():
-        for _ in range(20):
-            connector.receive_wave()
-    connector.stop()

+ 0 - 223
backend/tests/device/sig_chain/test_receive.py

@@ -1,223 +0,0 @@
-"""Module tests/core/sig_chain/test_receive provide test for receiver"""
-import pytest
-import time
-import unittest
-from unittest import mock
-
-from func_timeout import FunctionTimedOut
-
-from device.sig_chain.device.connector_interface import DataMode
-from device.sig_chain.device.connector_interface import Device
-from device.sig_chain.sig_receive import Receiver
-
-
-
-def teardown_function():
-    Receiver.clear_instance()
-
-
-def test_new_receiver_is_not_ready():
-    receiver = Receiver()
-    assert not receiver.is_ready
-
-
-def test_before_select_can_not_setup_connector():
-    receiver = Receiver()
-    with pytest.raises(AssertionError):
-        receiver.setup_connector()
-
-
-def test_before_setup_connector_receive_data_failed():
-    receiver = Receiver()
-    with pytest.raises(AssertionError):
-        receiver.start_receive_wave()
-
-
-def test_before_setup_connector_get_data_from_buffer_failed():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-
-    receiver.buffer_plot.get_sig = \
-        mock.MagicMock(return_value={'status': 'ok'})
-    with pytest.raises(RuntimeError):
-        receiver.get_data_from_buffer('plot')
-
-    receiver.buffer_classify_online.get_sig = \
-        mock.MagicMock(return_value={'status': 'ok'})
-    with pytest.raises(RuntimeError):
-        receiver.get_data_from_buffer('classify_online')
-
-
-def test_before_setup_connector_stop_receive_pass():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-    receiver.stop_receive()
-
-
-def test_after_setup_connector_is_ready():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-
-    receiver.connector.get_ready = mock.MagicMock(return_value=True)
-    receiver.setup_connector()
-    assert receiver.is_ready
-
-
-def test_after_setup_wave_receive_mode_is_ready():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-
-    receiver.connector.setup_wave_mode = mock.MagicMock(return_value=True)
-    receiver.setup_receive_mode(DataMode.WAVE)
-    assert receiver.is_ready
-
-
-def test_after_setup_impedance_receive_mode_is_ready():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-
-    receiver.connector.setup_impedance_mode = \
-        mock.MagicMock(return_value=True)
-    receiver.setup_receive_mode(DataMode.IMPEDANCE)
-    assert receiver.is_ready
-
-
-def test_failed_setup_wave_receive_mode_is_not_ready():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-
-    receiver.connector.setup_wave_mode = mock.MagicMock(return_value=False)
-    receiver.setup_receive_mode(DataMode.WAVE)
-    assert not receiver.is_ready
-
-
-def test_failed_setup_impedance_receive_mode_is_not_ready():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-
-    receiver.connector.setup_impedance_mode = mock.MagicMock(return_value=False)
-    receiver.setup_receive_mode(DataMode.IMPEDANCE)
-    assert not receiver.is_ready
-
-
-def test_before_ready_stop_receive_pass():
-    receiver = Receiver()
-    receiver.stop_receive()
-
-
-def test_after_stop_receive_is_not_ready():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-
-    receiver.connector.get_ready = mock.MagicMock(return_value=True)
-    receiver.setup_connector()
-    receiver.connector.stop = mock.MagicMock(return_value=True)
-    receiver.stop_receive()
-    assert not receiver.is_ready
-
-
-def test_receiver_singleton_keep_status():
-    receiver = Receiver()
-    receiver.is_ready = True
-    receiver = Receiver()
-    assert receiver.is_ready
-
-
-def test_change_wave_to_impedance_mode_success():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-
-    receiver.clear_all_buffer = mock.MagicMock()
-    receiver.connector.get_ready = mock.MagicMock(return_value=True)
-    receiver.setup_connector()
-
-    receiver.connector.stop = mock.MagicMock(return_value=True)
-    receiver.stop_receive()
-
-    receiver.connector.setup_impedance_mode = mock.MagicMock(return_value=True)
-    success = receiver.setup_receive_mode(DataMode.IMPEDANCE)
-    assert success
-
-
-def test_after_setup_connector_buffers_are_cleared():
-    receiver = Receiver()
-
-    mock_connector = mock.MagicMock()
-    mock_connector.get_ready.return_value = True
-    receiver.connector = mock_connector
-
-    mock_clear_all_buffer = mock.MagicMock()
-    receiver.clear_all_buffer = mock_clear_all_buffer
-
-    receiver.setup_connector()
-    assert mock_clear_all_buffer.called
-
-
-def test_after_reset_receive_mode_buffers_are_cleared():
-    receiver = Receiver()
-
-    mock_connector = mock.MagicMock()
-    mock_connector.setup_wave_mode = mock.MagicMock()
-    receiver.connector = mock_connector
-
-    mock_clear_all_buffer = mock.MagicMock()
-    receiver.clear_all_buffer = mock_clear_all_buffer
-
-    receiver.setup_receive_mode(DataMode.WAVE)
-    assert mock_clear_all_buffer.called
-
-
-def test_get_data_from_invalid_buffer_type_failed():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-    receiver.is_ready = True
-    with pytest.raises(AssertionError):
-        receiver.get_data_from_buffer('xxx')
-
-
-def test_get_data_from_buffer_success():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-    receiver.is_ready = True
-
-    mock_data = {'status': 'ok', 'data': 1}
-    receiver.buffer_plot.get_sig = \
-        mock.MagicMock(return_value=mock_data)
-
-    ret = receiver.get_data_from_buffer('plot')
-    assert ret == mock_data
-
-
-@unittest.skip('加入timeout机制会导致卡顿,因此删除此功能')
-def test_limit_time_to_get_data_from_buffer():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-    receiver.is_ready = True
-
-    receiver.buffer_plot.get_sig = \
-        mock.MagicMock(return_value={'status': 'warn'})
-
-    with pytest.raises(FunctionTimedOut):
-        receiver.get_data_from_buffer('plot')
-
-
-@unittest.skip('依赖硬件')
-def test_main():
-    receiver = Receiver()
-    receiver.select_connector(Device.NEO, 1)
-    if receiver.setup_connector():
-        receiver.start_receive_wave()
-    for _ in range(20):
-        time.sleep(1)
-        data_from_buffer = receiver.get_data_from_buffer('plot')
-        if data_from_buffer:
-            raw_data = data_from_buffer['data']
-            print(raw_data)
-    receiver.stop_receive()
-
-    receiver.setup_receive_mode(DataMode.IMPEDANCE)
-    for _ in range(20):
-        time.sleep(1)
-        impedance = receiver.receive_impedance()
-        if impedance:
-            print(impedance)

+ 0 - 143
backend/tests/device/sig_chain/test_sig_buffer.py

@@ -1,143 +0,0 @@
-"""单元测试sig_buffer"""
-import collections
-
-import numpy as np
-
-from device.sig_chain.sig_buffer import CircularBuffer
-from device.sig_chain.sig_buffer import ParserNewsetWithTime
-from device.sig_chain.sig_buffer import PaserNewsetWithoutTime
-
-TimeStamp = collections.namedtuple("Time", ["timestamp", "data"])
-data_len = 10
-package_len = 0.1
-sig_len = 0.5
-chan_labels = ["C3", "C4", "O1", "O2", "Oz"]
-chan_types = ["eeg"] * len(chan_labels)
-fs = 1000
-sig_mock = np.random.rand(len(chan_labels), int(package_len * fs))
-sig_mock_time = TimeStamp(2023, sig_mock)
-
-
-def test_update_is_success():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                          PaserNewsetWithoutTime())
-    ring_len = len(ring.content)
-    ring.update(sig_mock)
-    assert len(ring.content) == ring_len + 1
-
-
-def test_ring_is_full():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                          PaserNewsetWithoutTime())
-    for _ in range(0, int(data_len / package_len)):
-        ring.update(sig_mock)
-    assert len(ring.content) == data_len / package_len
-
-
-def test_ring_get_sig():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                          PaserNewsetWithoutTime())
-    for _ in range(0, int(data_len / package_len)):
-        ring.update(sig_mock)
-    _ = ring.get_sig()
-    assert len(ring.content) == 0
-
-
-def test_enough_ring_get_data_status_is_ok():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                          PaserNewsetWithoutTime())
-    for _ in range(0, int(data_len / package_len)):
-        ring.update(sig_mock)
-    data_get = ring.get_sig()
-    status, data = data_get.values()
-    assert data.get_data().shape == (len(chan_labels), ring.data_len * fs)
-    assert status == "ok"
-
-
-def test_not_enough_ring_get_data_status_is_warn():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                          PaserNewsetWithoutTime())
-    ring.update(sig_mock)
-    data_get = ring.get_sig()
-    status, data = data_get.values()
-    assert data is None
-    assert status == "warn"
-
-
-def test_update_is_success_time():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                          ParserNewsetWithTime())
-    ring_len = len(ring.content)
-    ring.update(sig_mock_time)
-    assert len(ring.content) == ring_len + 1
-
-
-def test_ring_is_full_time():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                          ParserNewsetWithTime())
-    for _ in range(0, int(data_len / package_len)):
-        ring.update(sig_mock_time)
-    assert len(ring.content) == data_len / package_len
-
-
-def test_ring_get_sig_time():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                          ParserNewsetWithTime())
-    for _ in range(0, int(data_len / package_len)):
-        ring.update(sig_mock_time)
-    _ = ring.get_sig()
-    assert len(ring.content) == 0
-
-
-def test_enough_ring_get_data_status_is_ok_time():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                          ParserNewsetWithTime())
-    for _ in range(0, int(data_len / package_len)):
-        ring.update(sig_mock_time)
-    data_get = ring.get_sig()
-    status = data_get["status"]
-    data = data_get["data"]
-    my_time_stamp = data_get["timestamp"]
-    assert data.get_data().shape == (len(chan_labels), ring.data_len * fs)
-    assert status == "ok"
-    assert my_time_stamp[0] == 2023
-
-
-def test_not_enough_ring_get_data_status_is_warn_time():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                          ParserNewsetWithTime())
-    ring.update(sig_mock_time)
-    data_get = ring.get_sig()
-    status = data_get["status"]
-    data = data_get["data"]
-    my_time_stamp = data_get["timestamp"]
-    assert data is None
-    assert status == "warn"
-    assert my_time_stamp is None
-
-
-def test_get_sig_with_clear():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                            ParserNewsetWithTime())
-    for _ in range(0, int(data_len / package_len)):
-        ring.update(sig_mock_time)
-    ring.get_sig(clear=True)
-    assert len(ring.content) == 0
-
-
-def test_get_sig_without_clear():
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                            ParserNewsetWithTime())
-    for _ in range(0, int(data_len / package_len)):
-        ring.update(sig_mock_time)
-    ring.get_sig(clear=False)
-    assert len(ring.content) == len(ring.content)
-
-
-def test_not_enough_ring_get_sig_without_clear():
-
-    ring = CircularBuffer(data_len, package_len, chan_labels, chan_types, fs,
-                          ParserNewsetWithTime())
-    ring.update(sig_mock_time)
-    ring.get_sig(clear=False)
-    assert len(ring.content) == 1

+ 0 - 33
backend/tests/device/sig_chain/test_sig_reader.py

@@ -1,33 +0,0 @@
-"""单元测试 sig_reader"""
-import collections
-import os
-
-from device.sig_chain.sig_reader import Reader
-
-TEST_DATA_PATH = "tests/data/"
-BDF_FILE_PATH = os.path.join(TEST_DATA_PATH, "5_3_right_hand.bdf")
-
-
-def test_read():
-    ch_names = [
-        "Fz", "Fp1", "F3", "F7", "C3", "T3", "T5", "P3", "O1", "Cz", "Oz", "Pz",
-        "O2", "P4", "T6", "T4", "C4", "F8", "F4", "Fp2"
-    ]
-    reader = Reader()
-    raw = reader.read(BDF_FILE_PATH, tuple(ch_names))
-    assert (20, 386000) == raw.get_data().shape
-
-
-def test_fix_annotation():
-    ch_names = [
-        "Fz", "Fp1", "F3", "F7", "C3", "T3", "T5", "P3", "O1", "Cz", "Oz", "Pz",
-        "O2", "P4", "T6", "T4", "C4", "F8", "F4", "Fp2"
-    ]
-    reader = Reader()
-    raw = reader.read(BDF_FILE_PATH, tuple(ch_names))
-    reader.fix_annotation(raw)
-
-    ret = collections.Counter(raw.annotations.description)
-    assert 1 == ret["initialRest"]
-    assert 15 == ret["mi"]
-    assert 15 == ret["rest"]

+ 0 - 108
backend/tests/device/sig_chain/test_sig_save.py

@@ -1,108 +0,0 @@
-"""单元测试 sig_save"""
-import os
-
-import numpy as np
-
-from device.sig_chain.sig_save import SigSave
-from schemas.subjects import SubjectCreate
-
-channel_labels = [
-    'T6', 'P4', 'Pz', 'M2', 'F8', 'F4', 'Fp1', 'Cz', 'M1', 'F7', 'F3', 'C3',
-    'T3', 'A1', 'Oz', 'O1', 'O2', 'Fz', 'C4', 'T4', 'Fp2', 'A2', 'T5', 'P3'
-]
-test_data_path = './tests/core/sig_chain/test_data/'
-
-filename = 'testfilename.bdf'
-
-TASK_PER_RUN = 1
-
-
-def setup_module():
-    if not os.path.exists(test_data_path):
-        os.makedirs(test_data_path)
-
-
-def teardown_module():
-    os.removedirs(test_data_path)
-
-
-def create_subject():
-    return SubjectCreate(name='nobody',
-                         id_card='12345',
-                         gender='男',
-                         birthday='1988-01-01',
-                         rehabilitation_parts=['左手'])
-
-
-def test_subject_set_edf_header_success():
-    saver = SigSave(channel_labels, 1000, 375000, -375000)
-    subject = create_subject()
-    saver.set_edf_header(subject, filename, TASK_PER_RUN, test_data_path)
-    assert saver.is_ready is True
-
-
-def test_close_edf_file():
-    saver = SigSave(channel_labels, 1000, 375000, -375000)
-    subject = create_subject()
-    saver.set_edf_header(subject, filename, TASK_PER_RUN, test_data_path)
-    saver.close_edf_file()
-    assert saver.is_ready is False
-
-
-def test_save_raw_data_once():
-    saver = SigSave(channel_labels, 1000, 375000, -375000)
-    subject = create_subject()
-    saver.set_edf_header(subject, filename, TASK_PER_RUN, test_data_path)
-    data = np.ones((len(channel_labels), 1000))
-    saver.save_raw_data(data)
-    file_path = test_data_path + filename
-    assert os.path.exists(file_path)
-    assert os.path.getsize(file_path) > 0
-    saver.close_edf_file()
-    os.remove(file_path)
-
-
-def test_save_raw_data_10_times():
-    saver = SigSave(channel_labels, 1000, 375000, -375000)
-    subject = create_subject()
-    saver.set_edf_header(subject, filename, TASK_PER_RUN, test_data_path)
-    data = np.ones((len(channel_labels), 1000))
-    for _ in range(10):
-        saver.save_raw_data(data)
-    file_path = test_data_path + filename
-    assert os.path.exists(file_path)
-    assert os.path.getsize(file_path) > 0
-    saver.close_edf_file()
-    os.remove(file_path)
-
-
-def test_save_raw_data_without_set_header():
-    saver = SigSave(channel_labels, 1000, 375000, -375000)
-    data = np.ones((len(channel_labels), 1000))
-    saver.save_raw_data(data)
-    file_path = test_data_path + filename
-    assert os.path.exists(file_path) is False
-
-
-def test_edf_data_mark():
-    saver = SigSave(channel_labels, 1000, 375000, -375000)
-    subject = create_subject()
-    saver.set_edf_header(subject, filename, TASK_PER_RUN, test_data_path)
-    data = np.ones((len(channel_labels), 1000))
-    saver.save_raw_data(data, 500)
-    file_path = test_data_path + filename
-    saver.edf_data_mark(550, 'OK')
-    saver.close_edf_file()
-    os.remove(file_path)
-
-
-def test_edf_data_mark_timestamp_none():
-    saver = SigSave(channel_labels, 1000, 375000, -375000)
-    subject = create_subject()
-    saver.set_edf_header(subject, filename, TASK_PER_RUN, test_data_path)
-    data = np.ones((len(channel_labels), 1000))
-    saver.save_raw_data(data)
-    file_path = test_data_path + filename
-    saver.edf_data_mark(0.5, 'OK')
-    saver.close_edf_file()
-    os.remove(file_path)

+ 0 - 264
backend/tests/device/test_utils.py

@@ -1,264 +0,0 @@
-"""Test for video analyser """
-import os
-import time
-from unittest.mock import patch
-from unittest.mock import MagicMock
-
-import cv2
-import numpy as np
-
-from device.utils import VideoAnalyser
-
-
-TEST_DATA_PATH = 'tests/data/'
-INPUT_VIDEO_PATH = os.path.join(TEST_DATA_PATH, 'normal_side.mp4')
-OUTPUT_VIDEO_PATH = os.path.join(TEST_DATA_PATH, 'test_base.mp4')
-
-
-def setup_module():
-    if not os.path.exists(TEST_DATA_PATH):
-        os.makedirs(TEST_DATA_PATH)
-
-
-def teardown_function():
-    if os.path.exists(OUTPUT_VIDEO_PATH):
-        os.remove(OUTPUT_VIDEO_PATH)
-
-
-def gen_fake_image():
-    return np.zeros((640, 320, 3), dtype=np.uint8)
-
-
-class TestVideoAnalyser:
-    def test_init_without_input_video_is_camera(self):
-        analyser = VideoAnalyser(input_video=None)
-
-        assert analyser.is_camera
-
-    def test_init_with_input_video_is_not_camera(self):
-        mock_video_capture = MagicMock()
-        mock_video_capture.release = MagicMock()
-        with patch('cv2.VideoCapture', mock_video_capture):
-            analyser = VideoAnalyser(input_video=INPUT_VIDEO_PATH)
-
-        assert not analyser.is_camera
-
-    def test_close_with_opencv_release_resource(self):
-        analyser = VideoAnalyser()
-        analyser.set_output_video(output_video='output.mp4', save_with_av=False)
-
-        analyser.close()
-
-        assert not analyser.cap.isOpened()
-        assert not analyser.out_stream
-
-    def test_close_with_av_release_resource(self):
-        mock_release_container = MagicMock()
-        analyser = VideoAnalyser()
-        analyser.set_output_video(output_video='output.mp4', save_with_av=True)
-        analyser.release_container = mock_release_container
-
-        analyser.close()
-
-        assert not analyser.cap.isOpened()
-        assert not analyser.container
-        assert mock_release_container.called
-
-    def test_set_output_video_with_camera_and_opencv_success(self):
-        analyser = VideoAnalyser(camera_id=0)
-        analyser.open_camera()
-        analyser.set_output_video(output_video='output.mp4', save_with_av=False)
-
-        assert analyser.out_stream
-
-    def test_set_output_video_with_camera_and_av_success(self):
-        analyser = VideoAnalyser(camera_id=0)
-        analyser.open_camera()
-        analyser.set_output_video(output_video='output.mp4', save_with_av=True)
-
-        assert analyser.stream
-        assert analyser.container
-
-    def test_set_output_video_with_video_and_opencv_success(self):
-        analyser = VideoAnalyser(input_video=INPUT_VIDEO_PATH)
-        analyser.set_output_video(output_video='output.mp4', save_with_av=False)
-
-        assert analyser.out_stream
-
-    def test_set_output_video_with_video_and_av_success(self):
-        analyser = VideoAnalyser()
-        analyser.set_output_video(output_video='output.mp4', save_with_av=True)
-
-        assert analyser.stream
-        assert analyser.container
-
-    def test_is_ok_before_cap_open_return_false(self):
-        with patch('cv2.VideoCapture') as mock_cap:
-            mock_cap_instance = mock_cap.return_value
-            mock_cap_instance.isOpened.return_value = False
-            analyser = VideoAnalyser()
-
-            assert not analyser.is_ok()
-
-    def test_is_ok_after_cap_open_return_true(self):
-        with patch('cv2.VideoCapture') as mock_cap:
-            mock_cap_instance = mock_cap.return_value
-            mock_cap_instance.isOpened.return_value = True
-            analyser = VideoAnalyser()
-
-            assert analyser.is_ok()
-
-    def test_process_with_save_when_read_success_save_video(self):
-        with patch('cv2.VideoCapture') as mock_cap:
-            mock_cap_instance = mock_cap.return_value
-            mock_cap_instance.read.return_value = (True, gen_fake_image())
-            mock_save_video = MagicMock()
-
-            analyser = VideoAnalyser()
-            analyser.save_video = mock_save_video
-            analyser.process(save=True)
-
-            assert mock_save_video.called
-
-    def test_process_with_save_when_read_failed_not_save_video(self):
-        with patch('cv2.VideoCapture') as mock_cap:
-            mock_cap_instance = mock_cap.return_value
-            mock_cap_instance.read.return_value = (False, None)
-            mock_save_video = MagicMock()
-
-            analyser = VideoAnalyser()
-            analyser.save_video = mock_save_video
-            analyser.process(save=True)
-
-            assert not mock_save_video.called
-
-    def test_process_without_save_when_read_success_not_save_video(self):
-        with patch('cv2.VideoCapture') as mock_cap:
-            mock_cap_instance = mock_cap.return_value
-            mock_cap_instance.read.return_value = (False, None)
-            mock_save_video = MagicMock()
-
-            analyser = VideoAnalyser()
-            analyser.save_video = mock_save_video
-            analyser.process(save=False)
-
-            assert not mock_save_video.called
-
-    def test_save_video_with_av_av_function_called(self):
-        mock_save_video_with_av = MagicMock()
-
-        analyser = VideoAnalyser()
-        analyser.save_video_with_av = mock_save_video_with_av
-        analyser.save_with_av = True
-        analyser.save_video(gen_fake_image(), time.time())
-
-        assert mock_save_video_with_av.called
-
-    def test_save_video_without_av_opencv_function_called(self):
-        mock_save_video_with_opencv = MagicMock()
-
-        analyser = VideoAnalyser()
-        analyser.save_video_with_opencv = mock_save_video_with_opencv
-        analyser.save_with_av = False
-        analyser.save_video(gen_fake_image(), None)
-
-        assert mock_save_video_with_opencv.called
-
-    def test_save_video_with_opencv_before_set_output_video_pass(self):
-        analyser = VideoAnalyser()
-        analyser.save_video_with_opencv(gen_fake_image())
-
-    def test_save_video_with_opencv_with_out_stream_log_error(self):
-        mock_out_stream = MagicMock()
-        mock_out_stream.isOpened.return_value = False
-        analyser = VideoAnalyser()
-        analyser.out_stream = mock_out_stream
-
-        mock_logging_error = MagicMock()
-        with patch('core.utils.logger.error', mock_logging_error):
-            analyser.save_video_with_opencv(gen_fake_image())
-            assert mock_logging_error.called
-
-    def test_save_video_with_av_before_set_output_video_pass(self):
-        analyser = VideoAnalyser()
-        analyser.save_video_with_av(gen_fake_image(), time.time())
-
-    def test_save_video_with_av_with_antecedent_frame_skipped(self):
-        analyser = VideoAnalyser()
-        analyser.container = MagicMock()
-        analyser.stream = MagicMock()
-        mock_encode = MagicMock()
-        analyser.stream.encode = mock_encode
-
-        t_start = time.time()
-        analyser.t_start_save_video = t_start + 12
-        analyser.save_video_with_av(gen_fake_image(), t_start)
-
-        assert not mock_encode.called
-
-    def test_save_video_with_av_with_valid_frame_success(self):
-        analyser = VideoAnalyser()
-        # analyser = VideoAnalyser(input_video=INPUT_VIDEO_PATH)
-        # analyser.set_output_video(output_video='output.mp4', save_with_av=True)
-        analyser.container = MagicMock()
-        analyser.container.mux = MagicMock()
-        analyser.stream = MagicMock()
-        mock_encode = MagicMock()
-        analyser.stream.encode = mock_encode
-
-        t_start = time.time()
-        analyser.t_start_save_video = t_start - 1
-        analyser.save_video_with_av(gen_fake_image(), t_start)
-
-        assert mock_encode.called
-
-    def test_release_container_without_save_not_finish_with_a_blank_frame(self):
-        mock_av_finish_with_a_blank_frame = MagicMock()
-        analyser = VideoAnalyser()
-        analyser.av_finish_with_a_blank_frame = \
-            mock_av_finish_with_a_blank_frame
-
-        analyser.set_output_video(output_video='output.mp4', save_with_av=True)
-        analyser.release_container()
-
-        assert not mock_av_finish_with_a_blank_frame.called
-
-    def test_release_container_reset_time_start_save_video(self):
-        mock_av_finish_with_a_blank_frame = MagicMock()
-        analyser = VideoAnalyser()
-        analyser.av_finish_with_a_blank_frame = \
-            mock_av_finish_with_a_blank_frame
-        analyser.t_start_save_video = time.time()
-
-        analyser.release_container()
-
-        assert not analyser.t_start_save_video
-
-    def test_release_container_reset_previous_pts(self):
-        mock_av_finish_with_a_blank_frame = MagicMock()
-        analyser = VideoAnalyser()
-        analyser.av_finish_with_a_blank_frame = \
-            mock_av_finish_with_a_blank_frame
-        analyser.previous_pts = 134
-
-        analyser.release_container()
-
-        assert 0 == analyser.previous_pts
-
-
-def test_main():
-    analyser = VideoAnalyser()
-    # analyser.set_output_video(output_video=OUTPUT_VIDEO_PATH, save_with_av=True)
-    analyser.set_output_video(output_video=OUTPUT_VIDEO_PATH)
-    count = 0
-    while analyser.is_ok():
-        count += 1
-        if count == 196:
-            break
-
-        _, image = analyser.process()
-        cv2.imshow('base', image)
-        if cv2.waitKey(1) & 0xFF == ord('q'): #press q to quit
-            break
-
-    cv2.destroyAllWindows()

+ 0 - 36
backend/tests/utils/core.py

@@ -1,36 +0,0 @@
-"""
-Author: linxiaohong linxiaohong@neuracle.cn
-Date: 2023-07-17 14:14:20
-LastEditors: linxiaohong linxiaohong@neuracle.cn
-LastEditTime: 2023-07-19 14:02:01
-FilePath: Albatross/backend/tests/utils/core.py
-Description: tests/core 中的测试共用的工具函数
-
-Copyright (c) 2023 by Neuracle, All Rights Reserved.
-"""
-import mne
-
-
-def get_epochs(raw, picks, event_name=None, tmin=0, tmax=1):
-    events, event_id = mne.events_from_annotations(raw)
-    if event_name is None:
-        event_id_pick = event_id
-    else:
-        event_id_pick = {event_name: event_id[event_name]}
-    epochs = mne.Epochs(raw,
-                        events,
-                        event_id_pick,
-                        tmin,
-                        tmax,
-                        picks=picks,
-                        baseline=None,
-                        preload=True)
-    return epochs
-
-
-def crop_by_annotation(raw, annot):
-    onset = annot["onset"] - raw.first_time
-    if -raw.info["sfreq"] / 2 < onset < 0:
-        onset = 0
-    raw_crop = raw.copy().crop(onset, onset + annot["duration"])
-    return raw_crop

+ 0 - 57
backend/tests/utils/subject.py

@@ -1,57 +0,0 @@
-"""testing subjects utils"""
-import random
-
-from sqlalchemy.orm import Session
-
-from db.repository import subjects as db_rep_sub
-from schemas.subjects import SubjectCreate
-from utils.utils import fake
-
-
-def generate_subject_fake_data():
-    return {
-        "name":
-            fake.name(),
-        "id_card":
-            None,
-        "birthday":
-            str(fake.date_between_dates(date_start="-100y", date_end="-5y")),
-        "gender":
-            fake.subject_gender(),
-        "rehabilitation_parts":
-            fake.rehabilitation_parts()
-    }
-
-
-def create_test_subject2db(db: Session,
-                       name=fake.name(),
-                       id_card=None,
-                       gender=fake.subject_gender(),
-                       birthday=fake.date_between_dates(date_start="-100y",
-                                                        date_end="-5y"),
-                       rehabilitation_parts=fake.rehabilitation_parts(),
-                       create_time=None) -> SubjectCreate:
-    if create_time is None:
-        subject = SubjectCreate(name=name,
-                                id_card=id_card,
-                                gender=gender,
-                                birthday=birthday,
-                                rehabilitation_parts=rehabilitation_parts)
-    else:
-        subject = SubjectCreate(name=name,
-                                id_card=id_card,
-                                gender=gender,
-                                birthday=birthday,
-                                rehabilitation_parts=rehabilitation_parts,
-                                create_time=create_time)
-    subject = db_rep_sub.create_subject(subject, db)
-    return subject
-
-
-def get_all_subject(db: Session):
-    return db_rep_sub.list_subjects(db=db)
-
-
-def get_random_existing_subject(db: Session):
-    subjects = get_all_subject(db)
-    return random.choice(subjects)

+ 0 - 38
backend/tests/utils/train.py

@@ -1,38 +0,0 @@
-"""testing subjects utils"""
-from sqlalchemy.orm import Session
-
-from db.repository import trains as db_rep_train
-from schemas.trains import TrainCreate
-from utils.utils import fake
-from utils.utils import get_random_position
-
-
-def generate_fake_train_data():
-    return {
-        "position": "左手",
-        "rank": fake.train_rank(),
-        "trial_num": fake.random_digit_not_null(),
-        "start_time": "2022-11-03 00:00",
-        "end_time": "2022-11-04 00:00"
-    }
-
-
-def create_test_train2db(db: Session,
-                      subject,
-                      position=None,
-                      rank=fake.train_rank(),
-                      trial_num=fake.random_digit_not_null(),
-                      start_time="2022-11-03 00:00",
-                      end_time="2022-11-04 00:00",
-                      device_parm=None) -> TrainCreate:
-    if position is None:
-        position = get_random_position(subject)
-    train = TrainCreate(position=position,
-                        rank=rank,
-                        trial_num=trial_num,
-                        start_time=start_time,
-                        end_time=end_time,
-                        device_parm=device_parm,
-                        owner_id=subject.id)
-    train = db_rep_train.create_train(train, db)
-    return train

+ 0 - 72
backend/tests/utils/utils.py

@@ -1,72 +0,0 @@
-"""provide fake data obj"""
-from datetime import datetime, timedelta
-import itertools
-import random
-
-from faker import Faker
-from faker.providers import DynamicProvider
-from sqlalchemy.orm import Session
-
-from db.models.subjects import Subject
-from db.models.trains import Train
-from db.models.hand_peripherals import HandPeripheral
-from db.models.daily_stats import DailyStats
-from db.repository import subjects as db_rep_sub
-from schemas.subjects import SubjectCreate
-
-
-class FakerManager:
-    """init fake obj"""
-
-    def __init__(self, lang="zh-cn"):
-        self.fake = Faker(lang)
-        self.load_provider()
-
-    def load_provider(self):
-        self.fake.add_provider(self.get_gender_provider())
-        self.fake.add_provider(self.get_parts_provider())
-        self.fake.add_provider(self.get_train_rank_provider())
-
-    @staticmethod
-    def get_gender_provider():
-        return DynamicProvider(provider_name="subject_gender",
-                               elements=["男", "女"])
-
-    @staticmethod
-    def get_parts_provider():
-        parts_list = []
-        for num in range(1, 5):
-            parts_list.extend(
-                list(itertools.combinations(["左手", "右手", "左腿", "右腿"], num)))
-        parts_provider = DynamicProvider(provider_name="rehabilitation_parts",
-                                         elements=parts_list)
-        return parts_provider
-
-    @staticmethod
-    def get_train_rank_provider():
-        return DynamicProvider(provider_name="train_rank",
-                               elements=["简单", "中等", "困难"])
-
-
-
-def get_random_position(subject):
-    return random.choice(subject.rehabilitation_parts)
-
-
-def generate_delay_datetime(delay_years: int):
-    today = datetime.today()
-    delay_time = timedelta(days=365*delay_years)
-    delay_datetime = today + delay_time
-    return delay_datetime.strftime("%Y-%m-%d")
-
-
-def clear_db_table(db: Session):
-    db.query(Subject).delete()
-    db.query(Train).delete()
-    db.query(HandPeripheral).delete()
-    db.query(DailyStats).delete()
-    db.commit()
-
-
-fake = FakerManager().fake
-

Some files were not shown because too many files changed in this diff