123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386 |
- """module apis/version1/route_eeg provide backend apis"""
- import asyncio
- import logging
- from fastapi import APIRouter
- from fastapi import Depends
- from fastapi import HTTPException, status
- from fastapi import WebSocket
- from func_timeout import FunctionTimedOut
- import numpy as np
- # from scipy import signal
- from sqlalchemy.orm import Session
- from core import utils
- from core.mi.eeg_csp import CSPBasedClassifier
- from core.mi.eeg_psd import PSDBasedClassifier
- from core.mi.utils import SelectedCspChannel
- from core.mi.utils import SelectedPsdChannel
- from core.mi.pipeline import BaselineModel
- from core.sig_chain.device.connector_interface import DataMode
- from core.sig_chain.device.connector_interface import Device
- from core.sig_chain.pre_process import PreProcessor
- from core.sig_chain.pre_process import RealTimeFilterM
- from core.sig_chain.sig_receive import Receiver
- from db.models.trains import Limbs
- from db.models.trains import TrainStatus
- from db.repository import subjects as db_rep_sub
- from db.repository import trains as db_rep_train
- from db.session import get_db
- from service import eeg as es
- from settings.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
# Module-level state shared by the endpoints in this file. csp_dc, psd_dc,
# psd_clf and csp_clf are rebound by eeg_clf_reset() via ``global``.
csp_dc = es.CSPDataCollector()
# 'rest_stim_duration' is divided by 1000 here, so it is presumably in
# milliseconds and maxlen in seconds — NOTE(review): confirm the unit.
psd_dc = es.PSDDataCollector(
    maxlen=int(settings.TRAIN_PARAMS['rest_stim_duration'] / 1000))
# Built without a sample rate; eeg_device_connect() later calls
# psd_clf.update_params() with the connected device's sample rate.
psd_clf = PSDBasedClassifier()
csp_clf = CSPBasedClassifier()
# Written by /set-train-finish-flag and read by /get-train-finish-flag.
train_finish_flag = False
# Constructed at import time from the bundled .pkl model file (presumably
# unpickled — only ship trusted model files). Buffer cleared by
# /eeg-pipeline-reset.
pipeline = BaselineModel("static/models/bp-baseline.pkl")
@router.get("/train-configs")
def get_train_configs():
    """Return the training parameters configured in the application settings.

    Returns:
        The ``settings.TRAIN_PARAMS`` object, serialized by FastAPI.
    """
    train_params = settings.TRAIN_PARAMS
    return train_params
@router.get("/eeg-edf-set-header")
def eeg_edf_set_header(subject_id: str = None,
                       train_id: int = None,
                       task_per_run: int = None,
                       db: Session = Depends(get_db)):
    """Create the BDF data header for a training session.

    Looks up the subject and the training record, creates the data
    directory, configures the receiver's saver with the BDF header and
    marks the training as ``TrainStatus.TRAINING``.

    Args:
        subject_id (str, optional): Subject (patient) ID. Defaults to None.
        train_id (int, optional): Training ID. Defaults to None.
        task_per_run (int, optional): Tasks per run, forwarded to the
            saver's ``set_edf_header``. Defaults to None.
        db (Session, optional): Database session. Defaults to
            Depends(get_db).

    Returns:
        int: 1 when the header has been created.

    Raises:
        HTTPException: 404 if the subject or the training record is not
            found.
    """
    # The retrieve_* helpers presumably return None when no row matches;
    # previously that crashed later with an AttributeError (HTTP 500).
    subject = db_rep_sub.retrieve_subject_by_id(id=subject_id, db=db)
    if subject is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
                            detail="Subject not found")
    train = db_rep_train.retrieve_train(id=train_id, db=db)
    if train is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
                            detail="Train not found")

    # Only create the on-disk directory once both records are known to exist.
    path = utils.create_data_dir(subject_id, train_id)

    # Use "test" when the training position has no Limbs name.
    limb_name = Limbs.get_item_name(train.position)
    position_name = limb_name.lower() if limb_name is not None else "test"
    start_stamp = train.start_time.strftime('%Y%m%d%H%M%S')
    filename = f"{subject.id_card}_{start_stamp}_{position_name}.bdf"

    receiver = Receiver()
    receiver.connector.set_saver()
    receiver.connector.saver.set_edf_header(subject, filename, task_per_run,
                                            path)
    update_dict = {"train_status": TrainStatus.TRAINING}
    db_rep_train.partial_update_train_by_id(train_id, update_dict, db)
    return 1
@router.get("/eeg-edf-mark")
def eeg_edf_mark(time_seconds: int, mark: str):
    """Write a marker into the recording through the receiver's saver.

    Args:
        time_seconds (int): Time point at which the marker is placed.
        mark (str): Marker text.

    Returns:
        int: 1 on success.
    """
    saver = Receiver().connector.saver
    saver.edf_data_mark(time_seconds, mark)
    return 1
@router.get("/eeg-device-connect")
def eeg_device_connect():
    """Connect to the EEG device selected in the configuration.

    Reads ``settings.config["test_parameter"]["device"]`` ("faker", "pony"
    or "neo"), selects the matching connector on the receiver, pushes the
    connector's sample rate into the PSD classifier and sets the connector
    up.

    Returns:
        dict: ``{"msg": <result of receiver.setup_connector()>}``.

    Raises:
        HTTPException: 503 if the device name is unknown, its config
            section is missing, or the connection fails.
    """
    # Device name -> (config section key, Device enum member).
    device_options = {
        "faker": ("faker_eeg_config", Device.FAKER),
        "pony": ("pony_eeg_config", Device.PONY),
        "neo": ("neo_eeg_config", Device.NEO),
    }
    device = settings.config["test_parameter"]["device"]
    receiver = Receiver()
    option = device_options.get(device)
    if option is None:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="Invalid device name")
    config_key, device_type = option
    config_info = settings.config.get(config_key)
    # Previously a missing section crashed with ``None.get`` (HTTP 500).
    if config_info is None:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail=f"Missing config section '{config_key}'")
    receiver.select_connector(device_type,
                              config_info.get("buffer_plot_size_seconds"),
                              config_info)
    psd_clf.update_params(receiver.connector.sample_params.sample_rate)
    success = receiver.setup_connector()
    if not success:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="EEG device connected failed")
    return {"msg": success}
# NOTE: get_wave_from_buffer() already pulls data directly; consider
# removing this endpoint later.
@router.get("/data-buffer")
def start_receive_wave():
    """Start receiving wave data into the receiver's buffer.

    Returns:
        dict: ``{"msg": "success"}``.

    Raises:
        HTTPException: 503 if the receiver raises AssertionError.
    """
    receiver = Receiver()
    try:
        receiver.start_receive_wave()
    except AssertionError as err:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="Start receive wave failed") from err
    else:
        return {"msg": "success"}
def _read_filtered_plot_data(receiver, eeg_filter):
    """Read, resample and filter the latest plot-buffer chunk.

    Args:
        receiver: The shared ``Receiver`` instance.
        eeg_filter: Real-time filter created by ``RealTimeFilterM.init_eeg``.

    Returns:
        tuple: ``(timestamp, eegdata)``. Both are None when the buffer status
        is not "ok". ``eegdata`` is the filtered array as nested lists (the
        original docstring described it as channels x samples).
    """
    receiver.connector.receive_wave()
    data_from_buffer = receiver.get_data_from_buffer("plot")
    if data_from_buffer["status"] != "ok":
        return None, None
    raw_waves = data_from_buffer["data"]
    # TODO: make the preprocessing parameters configurable.
    resampled_waves = PreProcessor.resample_direct(
        raw_waves, settings.config["frontend_plot"]["sample_rate"])
    # Fetch the array once; previously get_data() was called again for
    # every sample inside the loop below.
    waves = resampled_waves.get_data()
    filtered = np.zeros_like(waves, dtype=np.float64)
    # Feed the filter one sample column at a time, in order, as before. The
    # filter is presumably stateful — confirm before vectorizing.
    for col in range(waves.shape[1]):
        filtered[:, col] = eeg_filter.filter(waves[:, col])
    return data_from_buffer["timestamp"], filtered.tolist()


@router.websocket("/data")
async def get_wave_from_buffer(websocket: WebSocket):
    """Stream filtered EEG plot data to the client over a websocket.

    Every ``receiver.buffer_plot.package_len / 2`` seconds, sends a JSON
    message ``{"timestamp": ..., "eegdata": ...}``. Both values are None when
    the plot buffer was not ready.

    On a receiver ``RuntimeError`` or ``FunctionTimedOut`` the socket is
    closed with code 1011 (internal error). Raising ``HTTPException`` here
    cannot produce an HTTP response once the websocket has been accepted.
    A client disconnect ends the loop quietly.

    Args:
        websocket (WebSocket): The client connection.
    """
    from fastapi import WebSocketDisconnect  # pylint: disable=import-outside-toplevel

    receiver = Receiver()
    filter_m_high = RealTimeFilterM.init_eeg(
        0, receiver.connector.sample_params.channel_count)
    await websocket.accept()
    try:
        while True:
            # The interval must be shorter than the plot buffer; half of it
            # is used here.
            await asyncio.sleep(receiver.buffer_plot.package_len / 2)
            # NOTE(review): the reads below are synchronous and run on the
            # event-loop thread; consider a threadpool if they are slow.
            try:
                timestamp, eeg_data = _read_filtered_plot_data(
                    receiver, filter_m_high)
            except (RuntimeError, FunctionTimedOut):
                logger.exception("Reading EEG plot data failed")
                await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
                return
            await websocket.send_json({"timestamp": timestamp,
                                       "eegdata": eeg_data})
    except WebSocketDisconnect:
        logger.info("EEG data websocket client disconnected")
@router.get("/wave-mode-connect")
def wave_mode_connect():
    """Switch the receiver to EEG wave (waveform data) mode.

    Sets ``DataMode.WAVE`` on the shared receiver. The response body is
    null.
    """
    receiver = Receiver()
    receiver.setup_receive_mode(DataMode.WAVE)
@router.get("/impedance-model-connect")
def impedance_mode_connect():
    """Switch the receiver to impedance mode.

    The route path spells "model"; it is kept because clients call this URL.
    """
    Receiver().setup_receive_mode(DataMode.IMPEDANCE)
@router.get("/impedance-data")
def get_impedance():
    """Read the current impedance values from the receiver.

    Returns:
        dict: ``{"impedance": <value from receiver.receive_impedance()>}``.

    Raises:
        HTTPException: 503 if the receiver raises AssertionError.
    """
    receiver = Receiver()
    try:
        values = receiver.receive_impedance()
    except AssertionError as err:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Receive impedance failed",
        ) from err
    return {"impedance": values}
@router.get("/eeg-model-close")
def eeg_mode_close():
    """Close EEG data mode by stopping the receiver."""
    Receiver().stop_receive()
# TODO: merge the two close endpoints; this one delegates to eeg_mode_close()
# because both perform the same receiver shutdown.
@router.get("/impedance-model-close")
def impedance_mode_close():
    """Close impedance mode by stopping the receiver."""
    eeg_mode_close()
@router.get("/initial-rest-state-run")
def initial_rest_state_run(position: str, duration: int):
    """Run the resting-state phase at the very start of training.

    Selects the CSP and PSD channels for the trained body position, then
    delegates to ``es.initial_rest_process``.

    Args:
        position (str): Body position being trained.
        duration (int): Resting-state duration.

    Returns:
        dict: ``{"train_success": ...}`` as returned by
        ``es.initial_rest_process``.
    """
    receiver = Receiver()
    connector = receiver.connector
    # TODO: move channel selection into the reset step?
    csp_channels = SelectedCspChannel(connector.device).get_channel_ids(
        position, connector.sample_params.channel_labels)
    csp_dc.set_collected_channel(csp_channels)
    psd_channels = SelectedPsdChannel(connector.device).get_channel_ids(
        position, connector.sample_params.channel_labels)
    psd_dc.set_collected_channel(psd_channels)
    outcome = es.initial_rest_process(receiver, psd_dc, duration, psd_clf)
    return {"train_success": outcome}
@router.get("/mi-state-run")
def mi_state_run(current_round: int, duration: int, sample_duration: int):
    """Run the motor-imagery (MI) phase of one task.

    Args:
        current_round (int): Current round number.
        duration (int): Motor-imagery duration (seconds).
        sample_duration (int): Length (seconds) of the data sample used to
            train the CSP classifier; must not exceed ``duration``.

    Returns:
        dict: ``{"predicts": ...}`` from ``es.one_task_mi_process``
        (classification results; 1 = motor imagery, 0 = rest).

    Raises:
        HTTPException: 400 if ``duration < sample_duration``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O`` and otherwise surface to the client as HTTP 500.
    if duration < sample_duration:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail="Duration >= sample_duration not satisfied!")
    receiver = Receiver()
    predicts = es.one_task_mi_process(receiver, psd_dc, csp_dc, current_round,
                                      duration, sample_duration, psd_clf,
                                      csp_clf)
    return {"predicts": predicts}
@router.get("/rest-state-run")
def rest_state_run(tasks_per_round: int, duration: int, sample_duration: int):
    """Run the rest phase of one task.

    Args:
        tasks_per_round (int): Number of tasks per round.
        duration (int): Rest duration (seconds).
        sample_duration (int): Length (seconds) of the data sample used to
            train the CSP classifier; must not exceed ``duration``.

    Returns:
        list: ``["success"]``.

    Raises:
        HTTPException: 400 if ``duration < sample_duration``.
    """
    # Explicit check instead of ``assert`` (stripped under ``python -O``,
    # otherwise an HTTP 500).
    if duration < sample_duration:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail="Duration >= sample_duration not satisfied!")
    receiver = Receiver()
    es.one_task_rest_process(receiver, csp_dc, tasks_per_round, duration,
                             sample_duration, csp_clf)
    # The original returned the set literal {"success"}, which FastAPI
    # encodes as the JSON list ["success"]; return that list explicitly so
    # the wire format is unchanged.
    return ["success"]
@router.get("/mi-test-run")
def mi_test_run(current_round: int):
    """Classify the latest online buffer with the baseline pipeline.

    When a prediction is made, it is also written as a marker at the
    buffer's first timestamp.

    Args:
        current_round (int): Current round number (accepted for API
            compatibility; not used in the body).

    Returns:
        dict: ``{"predict": ...}``, the result of
        ``pipeline.smoothed_decision``, or None when the buffer status is
        not "ok".
    """
    receiver = Receiver()
    receiver.connector.receive_wave()
    buffered = receiver.get_data_from_buffer("classify_online")
    if buffered["status"] != "ok":
        return {"predict": None}
    prediction = pipeline.smoothed_decision(buffered)
    timestamps = buffered["timestamp"]
    receiver.connector.saver.edf_data_mark(timestamps[0], str(prediction))
    return {"predict": prediction}
@router.get("/eeg-pipeline-reset")
def eeg_pipeline_reset():
    """Clear the baseline pipeline's buffer after each successful decision."""
    baseline = pipeline
    baseline.reset_buffer()
@router.get("/eeg-clf-reset")
def eeg_clf_reset():
    """Re-create the classifiers and data collectors at the start of training.

    Rebinds the module-level ``csp_clf``, ``psd_clf``, ``csp_dc`` and
    ``psd_dc``. ``psd_clf`` is built with the connected device's sample rate.
    """
    global csp_clf, psd_clf, csp_dc, psd_dc  # pylint: disable=global-statement
    csp_clf = CSPBasedClassifier()
    sample_rate = Receiver().connector.sample_params.sample_rate
    psd_clf = PSDBasedClassifier(sample_rate)
    csp_dc = es.CSPDataCollector()
    rest_window = int(settings.TRAIN_PARAMS['rest_stim_duration'] / 1000)
    psd_dc = es.PSDDataCollector(maxlen=rest_window)
@router.get("/set-train-finish-flag")
def set_train_finish_flag(flag: bool):
    """Set the module-level training-finished flag.

    Args:
        flag (bool): New value. With the ``bool`` annotation, FastAPI parses
            query strings such as "true"/"false"/"1"/"0" into a real
            boolean. Without it, the raw string was stored, so "false" was
            truthy and the type differed from the initial ``False``.

    Returns:
        str: Confirmation message "设置成功" ("set successfully").
    """
    global train_finish_flag  # pylint: disable=global-statement
    train_finish_flag = flag
    return "设置成功"
@router.get("/get-train-finish-flag")
def get_train_finish_flag():
    """Report the current training-finished flag.

    Returns:
        dict: ``{"train_finish_flag": <current value>}``.
    """
    current = train_finish_flag
    return {"train_finish_flag": current}
@router.get("/restart-fake-data")
def restart_fake_data():
    """Reset the wave stream when the faker device is in use.

    Returns:
        dict: ``{"status": 1}``, whether or not a reset happened.
    """
    receiver = Receiver()
    is_faker = receiver.connector.device == Device.FAKER
    if is_faker:
        receiver.reset_wave()
    return {"status": 1}
|