"""Backend EEG API routes (module apis/version1/route_eeg).

Covers EEG device connection, BDF recording and marking, real-time wave
streaming over a WebSocket, impedance checks, and the rest / motor-imagery
(MI) training and online-test classification endpoints.
"""
import asyncio
import logging

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import status
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from func_timeout import FunctionTimedOut
import numpy as np
from sqlalchemy.orm import Session

from core import utils
from core.mi.eeg_csp import CSPBasedClassifier
from core.mi.eeg_psd import PSDBasedClassifier
from core.mi.utils import SelectedCspChannel
from core.mi.utils import SelectedPsdChannel
from core.mi.pipeline import BaselineModel
from core.sig_chain.device.connector_interface import DataMode
from core.sig_chain.device.connector_interface import Device
from core.sig_chain.pre_process import PreProcessor
from core.sig_chain.pre_process import RealTimeFilterM
from core.sig_chain.sig_receive import Receiver
from db.models.trains import Limbs
from db.models.trains import TrainStatus
from db.repository import subjects as db_rep_sub
from db.repository import trains as db_rep_train
from db.session import get_db
from service import eeg as es
from settings.config import settings

logger = logging.getLogger(__name__)

router = APIRouter()

# Module-level training state shared by the endpoints below.
# eeg_clf_reset() replaces these objects with fresh instances.
csp_dc = es.CSPDataCollector()
psd_dc = es.PSDDataCollector(
    maxlen=int(settings.TRAIN_PARAMS['rest_stim_duration'] / 1000))
psd_clf = PSDBasedClassifier()
csp_clf = CSPBasedClassifier()
train_finish_flag = False
pipeline = BaselineModel("static/models/bp-baseline.pkl")

# Configured device name -> (connector type, section name in settings.config).
_DEVICE_CONNECTORS = {
    "faker": (Device.FAKER, "faker_eeg_config"),
    "pony": (Device.PONY, "pony_eeg_config"),
    "neo": (Device.NEO, "neo_eeg_config"),
}


@router.get("/train-configs")
def get_train_configs():
    """Return the training parameters from ``settings.TRAIN_PARAMS``."""
    return settings.TRAIN_PARAMS


@router.get("/eeg-edf-set-header")
def eeg_edf_set_header(subject_id: str = None,
                       train_id: int = None,
                       task_per_run: int = None,
                       db: Session = Depends(get_db)):
    """Create the BDF data header for a training session.

    Looks up the subject and train, creates the data directory, configures
    the receiver's saver with a header and a file name of the form
    ``<id_card>_<YYYYmmddHHMMSS>_<position>.bdf``, then sets the train's
    status to ``TrainStatus.TRAINING``.

    Args:
        subject_id (str, optional): Subject ID. Defaults to None.
        train_id (int, optional): Train ID. Defaults to None.
        task_per_run (int, optional): Tasks per run, forwarded to the saver
            header. Defaults to None.
        db (Session, optional): Database session. Defaults to
            Depends(get_db).

    Returns:
        1: header created successfully.

    Raises:
        HTTPException: 404 if the subject or the train does not exist.
    """
    subject = db_rep_sub.retrieve_subject_by_id(id=subject_id, db=db)
    train = db_rep_train.retrieve_train(id=train_id, db=db)
    # Fail before creating directories or touching the saver; previously a
    # missing row surfaced as an AttributeError (HTTP 500).
    if subject is None or train is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
                            detail="Subject or train not found")
    path = utils.create_data_dir(subject_id, train_id)

    limb_name = Limbs.get_item_name(train.position)
    position_name = limb_name.lower() if limb_name is not None else "test"
    start_stamp = train.start_time.strftime('%Y%m%d%H%M%S')
    filename = f"{subject.id_card}_{start_stamp}_{position_name}.bdf"

    receiver = Receiver()
    receiver.connector.set_saver()
    receiver.connector.saver.set_edf_header(subject, filename, task_per_run,
                                            path)
    update_dict = {"train_status": TrainStatus.TRAINING}
    db_rep_train.partial_update_train_by_id(train_id, update_dict, db)
    return 1


@router.get("/eeg-edf-mark")
def eeg_edf_mark(time_seconds: int, mark: str):
    """Write an annotation into the recorded EDF/BDF data.

    Args:
        time_seconds (int): Time point of the mark.
        mark (str): Mark text.

    Returns:
        1: success.
    """
    receiver = Receiver()
    receiver.connector.saver.edf_data_mark(time_seconds, mark)
    return 1


@router.get("/eeg-device-connect")
def eeg_device_connect():
    """Connect to the EEG device named in ``test_parameter.device``.

    Selects the connector for ``faker``/``pony``/``neo`` with its config
    section, updates the PSD classifier with the device sample rate and
    sets up the connector.

    Returns:
        dict: ``{"msg": <result of receiver.setup_connector()>}``.

    Raises:
        HTTPException: 503 for an unknown device name, a missing config
            section, or when the connector setup fails.
    """
    device = settings.config["test_parameter"]["device"]
    receiver = Receiver()
    if device not in _DEVICE_CONNECTORS:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="Invalid device name")
    device_type, config_key = _DEVICE_CONNECTORS[device]
    config_info = settings.config.get(config_key)
    if config_info is None:
        # Previously crashed with AttributeError on None.get(...).
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail=f"Missing config section '{config_key}'")
    receiver.select_connector(device_type,
                              config_info.get("buffer_plot_size_seconds"),
                              config_info)
    psd_clf.update_params(receiver.connector.sample_params.sample_rate)
    success = receiver.setup_connector()
    if not success:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="EEG device connected failed")
    return {"msg": success}


# Data is fetched directly in get_wave_from_buffer; consider removing this
# endpoint later.
@router.get("/data-buffer")
def start_receive_wave():
    """Start receiving wave data into the receiver buffer.

    Raises:
        HTTPException: 503 if the receiver raises AssertionError.
    """
    receiver = Receiver()
    try:
        receiver.start_receive_wave()
    except AssertionError as exc:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="Start receive wave failed") from exc
    return {"msg": "success"}


@router.websocket("/data")
async def get_wave_from_buffer(websocket: WebSocket):
    """Stream filtered EEG data to the client over a WebSocket.

    On each tick (half a plot-buffer package length) the latest "plot"
    buffer is read, resampled to ``frontend_plot.sample_rate``, passed
    sample by sample through a real-time filter and sent as
    ``{"timestamp": ..., "eegdata": [[...], ...]}`` (channels x samples).
    When the buffer is not ready both fields are ``None``.

    HTTP status codes cannot be returned once the socket is accepted, so
    receiver errors close the connection instead: ``RuntimeError`` closes
    with 1011 (internal error) and ``FunctionTimedOut`` with 1013 (try again
    later). The loop ends when the client disconnects.
    """
    receiver = Receiver()
    filter_m_high = RealTimeFilterM.init_eeg(
        0, receiver.connector.sample_params.channel_count)
    await websocket.accept()
    while True:
        # The sleep must be shorter than the plot buffer; half of a package
        # length is used.
        await asyncio.sleep(receiver.buffer_plot.package_len / 2)
        ret = None
        timestamp = None
        try:
            receiver.connector.receive_wave()
            data_from_buffer = receiver.get_data_from_buffer("plot")
            if data_from_buffer["status"] == "ok":
                raw_waves = data_from_buffer["data"]
                timestamp = data_from_buffer["timestamp"]
                # TODO: configure pre-processing parameters
                resampled_waves = PreProcessor.resample_direct(
                    raw_waves, settings.config["frontend_plot"]["sample_rate"])
                # Fetch the array once instead of once per sample.
                waves = resampled_waves.get_data()
                _, samples = waves.shape
                m_yy = np.zeros_like(waves, dtype=np.float64)
                # The filter is stateful, so samples are fed in time order.
                for ii in range(samples):
                    m_yy[:, ii] = filter_m_high.filter(waves[:, ii])
                ret = m_yy.tolist()
        except RuntimeError:
            logger.exception("EEG websocket stream failed")
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
            return
        except FunctionTimedOut:
            logger.warning("EEG websocket stream timed out")
            await websocket.close(code=status.WS_1013_TRY_AGAIN_LATER)
            return
        try:
            await websocket.send_json({"timestamp": timestamp, "eegdata": ret})
        except WebSocketDisconnect:
            logger.info("EEG websocket client disconnected")
            return


@router.get("/wave-mode-connect")
def wave_mode_connect():
    """Switch the receiver to wave (EEG data) mode."""
    receiver = Receiver()
    receiver.setup_receive_mode(DataMode.WAVE)


@router.get("/impedance-model-connect")
def impedance_mode_connect():
    """Switch the receiver to impedance mode."""
    receiver = Receiver()
    receiver.setup_receive_mode(DataMode.IMPEDANCE)


@router.get("/impedance-data")
# pylint: disable=missing-raises-doc
def get_impedance():
    # pylint: enable=missing-raises-doc
    """Read impedance values from the receiver.

    Returns:
        JSON: ``{"impedance": ...}``.

    Raises:
        HTTPException(503): when the receiver raises AssertionError.
    """
    receiver = Receiver()
    try:
        impedance = receiver.receive_impedance()
    except AssertionError as exc:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="Receive impedance failed") from exc
    return {"impedance": impedance}


@router.get("/eeg-model-close")
def eeg_mode_close():
    """Stop receiving in EEG data mode."""
    receiver = Receiver()
    receiver.stop_receive()


# TODO: merge the two close endpoints
@router.get("/impedance-model-close")
def impedance_mode_close():
    """Stop receiving in impedance mode."""
    receiver = Receiver()
    receiver.stop_receive()


@router.get("/initial-rest-state-run")
def initial_rest_state_run(position: str, duration: int):
    """Run the initial resting-state step at the start of training.

    Selects the CSP and PSD collector channels for the trained position,
    then runs ``es.initial_rest_process``.

    Args:
        position (str): Trained body part.
        duration (int): Resting-state duration.

    Returns:
        dict: ``{"train_success": ...}`` from ``es.initial_rest_process``.
    """
    receiver = Receiver()
    # TODO: move channel selection into the reset step?
    csp_dc.set_collected_channel(
        SelectedCspChannel(receiver.connector.device).get_channel_ids(
            position, receiver.connector.sample_params.channel_labels))
    psd_dc.set_collected_channel(
        SelectedPsdChannel(receiver.connector.device).get_channel_ids(
            position, receiver.connector.sample_params.channel_labels))
    train_success = es.initial_rest_process(receiver, psd_dc, duration,
                                            psd_clf)
    return {"train_success": train_success}


def _require_duration_covers_sample(duration, sample_duration):
    """Reject a request whose CSP sample window exceeds the task duration.

    Replaces the former ``assert``, which is stripped under ``python -O``
    and otherwise surfaced as an HTTP 500.

    Raises:
        HTTPException: 400 when ``duration < sample_duration``.
    """
    if duration < sample_duration:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail="Duration >= sample_duration not satisfied!")


@router.get("/mi-state-run")
def mi_state_run(current_round: int, duration: int, sample_duration: int):
    """Run the motor-imagery (MI) phase of one task.

    Args:
        current_round (int): Current round number.
        duration (int): Motor-imagery duration (seconds).
        sample_duration (int): Length of the data sample used to train CSP
            (seconds).

    Returns:
        dict: ``{"predicts": ...}``, classification results where 1 is motor
        imagery and 0 is rest.

    Raises:
        HTTPException: 400 when ``duration < sample_duration``.
    """
    _require_duration_covers_sample(duration, sample_duration)
    receiver = Receiver()
    predicts = es.one_task_mi_process(receiver, psd_dc, csp_dc, current_round,
                                      duration, sample_duration, psd_clf,
                                      csp_clf)
    return {"predicts": predicts}


@router.get("/rest-state-run")
def rest_state_run(tasks_per_round: int, duration: int, sample_duration: int):
    """Run the rest phase of one task.

    Args:
        tasks_per_round (int): Number of tasks per round.
        duration (int): Rest duration (seconds).
        sample_duration (int): Length of the data sample used to train CSP
            (seconds).

    Returns:
        set: the set literal ``{"success"}``.

    Raises:
        HTTPException: 400 when ``duration < sample_duration``.
    """
    _require_duration_covers_sample(duration, sample_duration)
    receiver = Receiver()
    es.one_task_rest_process(receiver, csp_dc, tasks_per_round, duration,
                             sample_duration, csp_clf)
    # NOTE(review): this is a set, not a dict; FastAPI presumably encodes it as
    # ["success"]. Kept unchanged because the client may depend on that
    # shape -- confirm before switching to {"msg": "success"}.
    return {"success"}


@router.get("/mi-test-run")
def mi_test_run(current_round: int):
    """Classify the latest online buffer with the baseline pipeline.

    On a ready "classify_online" buffer, the pipeline's smoothed decision is
    computed and written as a mark at the buffer's first timestamp.

    Args:
        current_round (int): Current round number (currently unused).

    Returns:
        dict: ``{"predict": ...}``; ``None`` when the buffer is not ready.
    """
    receiver = Receiver()
    receiver.connector.receive_wave()
    data_from_buffer = receiver.get_data_from_buffer("classify_online")
    if data_from_buffer["status"] == "ok":
        predict = pipeline.smoothed_decision(data_from_buffer)
        timestamps = data_from_buffer["timestamp"]
        receiver.connector.saver.edf_data_mark(timestamps[0], str(predict))
    else:
        predict = None
    return {"predict": predict}


@router.get("/eeg-pipeline-reset")
def eeg_pipeline_reset():
    """Reset the pipeline buffer after each successful decision."""
    pipeline.reset_buffer()


@router.get("/eeg-clf-reset")
def eeg_clf_reset():
    """Replace classifiers and data collectors at the start of training."""
    global csp_clf
    global psd_clf
    global csp_dc
    global psd_dc
    csp_clf = CSPBasedClassifier()
    receiver = Receiver()
    psd_clf = PSDBasedClassifier(receiver.connector.sample_params.sample_rate)
    csp_dc = es.CSPDataCollector()
    psd_dc = es.PSDDataCollector(
        maxlen=int(settings.TRAIN_PARAMS['rest_stim_duration'] / 1000))


@router.get("/set-train-finish-flag")
def set_train_finish_flag(flag):
    """Store the train-finish flag for later retrieval.

    NOTE(review): ``flag`` has no type annotation, so it is stored as the
    raw query value (e.g. the string "false", which is truthy). Confirm the
    frontend's expectations before annotating it as ``bool``.
    """
    global train_finish_flag
    train_finish_flag = flag
    return "设置成功"


@router.get("/get-train-finish-flag")
def get_train_finish_flag():
    """Return the stored train-finish flag."""
    return {"train_finish_flag": train_finish_flag}


@router.get("/restart-fake-data")
def restart_fake_data():
    """Reset the wave data when the faker device is connected."""
    receiver = Receiver()
    if receiver.connector.device == Device.FAKER:
        receiver.reset_wave()
    return {"status": 1}