route_eeg.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. """module apis/version1/route_eeg provide backend apis"""
  2. import asyncio
  3. import logging
  4. from fastapi import APIRouter
  5. from fastapi import Depends
  6. from fastapi import HTTPException, status
  7. from fastapi import WebSocket
  8. from func_timeout import FunctionTimedOut
  9. import numpy as np
  10. # from scipy import signal
  11. from sqlalchemy.orm import Session
  12. from core import utils
  13. from core.mi.eeg_csp import CSPBasedClassifier
  14. from core.mi.eeg_psd import PSDBasedClassifier
  15. from core.mi.utils import SelectedCspChannel
  16. from core.mi.utils import SelectedPsdChannel
  17. from core.mi.pipeline import BaselineModel
  18. from core.sig_chain.device.connector_interface import DataMode
  19. from core.sig_chain.device.connector_interface import Device
  20. from core.sig_chain.pre_process import PreProcessor
  21. from core.sig_chain.pre_process import RealTimeFilterM
  22. from core.sig_chain.sig_receive import Receiver
  23. from db.models.trains import Limbs
  24. from db.models.trains import TrainStatus
  25. from db.repository import subjects as db_rep_sub
  26. from db.repository import trains as db_rep_train
  27. from db.session import get_db
  28. from service import eeg as es
  29. from settings.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
# Module-level state shared by the endpoints below. ``eeg_clf_reset``
# rebinds csp_clf / psd_clf / csp_dc / psd_dc via ``global``, so endpoints
# read these names at call time instead of caching them.
csp_dc = es.CSPDataCollector()
# maxlen = rest_stim_duration / 1000, truncated to int
# (presumably milliseconds -> seconds; TODO confirm the unit).
psd_dc = es.PSDDataCollector(
    maxlen=int(settings.TRAIN_PARAMS['rest_stim_duration'] / 1000))
# Created without a sample rate; eeg_device_connect later calls
# psd_clf.update_params(<device sample rate>).
psd_clf = PSDBasedClassifier()
csp_clf = CSPBasedClassifier()
# Written by /set-train-finish-flag, read by /get-train-finish-flag.
train_finish_flag = False
# Model loaded once at import time from a relative path
# (presumably resolved against the process working directory; verify).
pipeline = BaselineModel("static/models/bp-baseline.pkl")
  39. @router.get("/train-configs")
  40. def get_train_configs():
  41. return settings.TRAIN_PARAMS
  42. @router.get("/eeg-edf-set-header")
  43. def eeg_edf_set_header(subject_id: str = None,
  44. train_id: int = None,
  45. task_per_run: int = None,
  46. db: Session = Depends(get_db)):
  47. """创建BDF数据的数据头
  48. Args:
  49. subject_id (str, optional): 患者ID. Defaults to None.
  50. train_id (int, optional): 训练ID. Defaults to None.
  51. db (Session, optional): 数据库. Defaults to Depends(get_db).
  52. Returns:
  53. 1: 创建成功
  54. """
  55. path = utils.create_data_dir(subject_id, train_id)
  56. subject = db_rep_sub.retrieve_subject_by_id(id=subject_id, db=db)
  57. train = db_rep_train.retrieve_train(id=train_id, db=db)
  58. position_name = "test"
  59. if Limbs.get_item_name(train.position) is not None:
  60. position_name = Limbs.get_item_name(train.position).lower()
  61. # pylint: disable=line-too-long
  62. filename = f"{subject.id_card}_{train.start_time.strftime('%Y%m%d%H%M%S')}_{position_name}.bdf"
  63. # pylint: enable=line-too-long
  64. receiver = Receiver()
  65. receiver.connector.set_saver()
  66. receiver.connector.saver.set_edf_header(subject, filename, task_per_run,
  67. path)
  68. update_dict = {"train_status": TrainStatus.TRAINING}
  69. db_rep_train.partial_update_train_by_id(train_id, update_dict, db)
  70. return 1
  71. @router.get("/eeg-edf-mark")
  72. def eeg_edf_mark(time_seconds: int, mark: str):
  73. """数据打标签
  74. Args:
  75. time_seconds (int): 打标签的时间点
  76. mark (str): 标记
  77. Returns:
  78. 1: 成功
  79. """
  80. receiver = Receiver()
  81. receiver.connector.saver.edf_data_mark(time_seconds, mark)
  82. return 1
  83. @router.get("/eeg-device-connect")
  84. def eeg_device_connect():
  85. """脑电设备连接
  86. """
  87. device = settings.config["test_parameter"]["device"]
  88. receiver = Receiver()
  89. if device == "faker":
  90. config_info = settings.config.get("faker_eeg_config")
  91. receiver.select_connector(Device.FAKER,
  92. config_info.get("buffer_plot_size_seconds"),
  93. config_info)
  94. elif device == "pony":
  95. config_info = settings.config.get("pony_eeg_config")
  96. receiver.select_connector(Device.PONY,
  97. config_info.get("buffer_plot_size_seconds"),
  98. config_info)
  99. elif device == "neo":
  100. config_info = settings.config.get("neo_eeg_config")
  101. receiver.select_connector(Device.NEO,
  102. config_info.get("buffer_plot_size_seconds"),
  103. config_info)
  104. else:
  105. raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  106. detail="Invalid device name")
  107. psd_clf.update_params(receiver.connector.sample_params.sample_rate)
  108. success = receiver.setup_connector()
  109. if success:
  110. return {"msg": success}
  111. else:
  112. raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  113. detail="EEG device connected failed")
  114. # 在get_wave_from_buffer直接获取数据, 后续考虑删除
  115. @router.get("/data-buffer")
  116. def start_receive_wave():
  117. """获取数据到buffer
  118. """
  119. receiver = Receiver()
  120. try:
  121. receiver.start_receive_wave()
  122. except AssertionError as exc:
  123. raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  124. detail="Start receive wave failed") from exc
  125. return {"msg": "success"}
  126. @router.websocket("/data")
  127. # pylint: disable=redundant-returns-doc
  128. # pylint: disable=missing-raises-doc
  129. async def get_wave_from_buffer(websocket: WebSocket):
  130. # pylint: enable=redundant-returns-doc
  131. # pylint: enable=missing-raises-doc
  132. """获取脑电数据
  133. Returns:
  134. JSON: 返回的脑电数据,(通道数*采样率)
  135. raises HTTPException: status状态错误(500)
  136. raises HTTPException: status状态错误(408)
  137. """
  138. receiver = Receiver()
  139. filter_m_high = RealTimeFilterM.init_eeg(
  140. 0, receiver.connector.sample_params.channel_count)
  141. await websocket.accept()
  142. while True:
  143. # 时间参数要比plot buffer小, 可以考虑以plot buffer的二分之一设置
  144. await asyncio.sleep(receiver.buffer_plot.package_len / 2)
  145. try:
  146. # await websocket.receive_text()
  147. ret = None
  148. timestamp = None
  149. receiver.connector.receive_wave()
  150. data_from_buffer = receiver.get_data_from_buffer("plot")
  151. if data_from_buffer["status"] == "ok":
  152. raw_waves = data_from_buffer["data"]
  153. timestamp = data_from_buffer["timestamp"]
  154. #TODO: 预处理的相关参数设置
  155. resampled_waves = PreProcessor.resample_direct(
  156. raw_waves, settings.config["frontend_plot"]["sample_rate"])
  157. _, samples = resampled_waves.get_data().shape
  158. m_yy = np.zeros_like(resampled_waves.get_data(),
  159. dtype=np.float64)
  160. for ii in range(samples):
  161. xn = resampled_waves.get_data()[:, ii]
  162. m_yy[:, ii] = filter_m_high.filter(xn)
  163. ret = m_yy.tolist()
  164. # await websocket.send_json(ret)
  165. # ret = raw_waves.get_data().tolist()
  166. except RuntimeError as exc:
  167. raise HTTPException(
  168. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from exc
  169. except FunctionTimedOut as exc:
  170. raise HTTPException(
  171. status_code=status.HTTP_408_REQUEST_TIMEOUT) from exc
  172. await websocket.send_json({"timestamp": timestamp, "eegdata": ret})
  173. @router.get("/wave-mode-connect")
  174. def wave_mode_connect():
  175. """阻抗模式连接
  176. """
  177. receiver = Receiver()
  178. receiver.setup_receive_mode(DataMode.WAVE)
  179. @router.get("/impedance-model-connect")
  180. def impedance_mode_connect():
  181. """阻抗模式连接
  182. """
  183. receiver = Receiver()
  184. receiver.setup_receive_mode(DataMode.IMPEDANCE)
  185. @router.get("/impedance-data")
  186. # pylint: disable=missing-raises-doc
  187. def get_impedance():
  188. # pylint: enable=missing-raises-doc
  189. """阻抗数据获取
  190. Returns:
  191. JSON: 阻抗数据
  192. Raises: HTTPException(503)
  193. """
  194. receiver = Receiver()
  195. try:
  196. impedance = receiver.receive_impedance()
  197. except AssertionError as exc:
  198. raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  199. detail="Receive impedance failed") from exc
  200. return {"impedance": impedance}
  201. @router.get("/eeg-model-close")
  202. def eeg_mode_close():
  203. """脑电数据模式关闭
  204. """
  205. receiver = Receiver()
  206. receiver.stop_receive()
  207. # TODO: 两个close 合并
  208. @router.get("/impedance-model-close")
  209. def impedance_mode_close():
  210. """阻抗模式关闭
  211. """
  212. receiver = Receiver()
  213. receiver.stop_receive()
  214. @router.get("/initial-rest-state-run")
  215. def initial_rest_state_run(position: str, duration: int):
  216. """训练最开始的静息处理
  217. Args:
  218. position (str): 训练部位
  219. duration (int): 静息态持续时间
  220. """
  221. # logger.debug("训练部位:%s", position)
  222. receiver = Receiver()
  223. # TODO: 放到reset部分?
  224. csp_dc.set_collected_channel(
  225. SelectedCspChannel(receiver.connector.device).get_channel_ids(
  226. position, receiver.connector.sample_params.channel_labels))
  227. psd_dc.set_collected_channel(
  228. SelectedPsdChannel(receiver.connector.device).get_channel_ids(
  229. position, receiver.connector.sample_params.channel_labels))
  230. train_success = es.initial_rest_process(receiver, psd_dc, duration, psd_clf)
  231. return {"train_success": train_success}
  232. @router.get("/mi-state-run")
  233. def mi_state_run(current_round: int, duration: int, sample_duration:int):
  234. """一个任务的mi过程处理
  235. Args:
  236. current_round (int): 当前轮数.
  237. duration (int): 运动想象是时长(second).
  238. sample_duration (int): 用于训练CSP的数据样本长度(second).
  239. Returns:
  240. list: predicts, 分类结果, 1是运动想象, 0是静息
  241. """
  242. assert duration >= sample_duration, \
  243. "Duration >= sample_duration not satisfied!"
  244. receiver = Receiver()
  245. predicts = es.one_task_mi_process(receiver, psd_dc, csp_dc, current_round,
  246. duration, sample_duration, psd_clf,
  247. csp_clf)
  248. return {"predicts": predicts}
  249. @router.get("/rest-state-run")
  250. def rest_state_run(tasks_per_round: int, duration: int, sample_duration: int):
  251. """一个任务的rest过程处理
  252. Args:
  253. tasks_per_round (int): 每轮的任务个数.
  254. duration (int): 休息时长(second).
  255. sample_duration (int): 用于训练CSP的数据样本长度(second).
  256. Returns:
  257. _type_: _description_
  258. """
  259. assert duration >= sample_duration, "Duration >= sample_duration not satisfied!"
  260. receiver = Receiver()
  261. es.one_task_rest_process(receiver, csp_dc, tasks_per_round, duration,
  262. sample_duration, csp_clf)
  263. return {"success"}
  264. @router.get("/mi-test-run")
  265. def mi_test_run(current_round: int):
  266. """一个任务的mi过程处理
  267. Args:
  268. current_round (int): 当前轮数.
  269. Returns:
  270. list: predicts, 分类结果, 1是运动想象, 0是静息
  271. """
  272. receiver = Receiver()
  273. receiver.connector.receive_wave()
  274. data_from_buffer = receiver.get_data_from_buffer("classify_online")
  275. if data_from_buffer["status"] == "ok":
  276. predict = pipeline.smoothed_decision(data_from_buffer)
  277. timestamps = data_from_buffer["timestamp"]
  278. receiver.connector.saver.edf_data_mark(timestamps[0], str(predict))
  279. else:
  280. predict = None
  281. return {"predict": predict}
  282. @router.get("/eeg-pipeline-reset")
  283. def eeg_pipeline_reset():
  284. """每次判成功后reset pipeline buffer
  285. """
  286. pipeline.reset_buffer()
  287. @router.get("/eeg-clf-reset")
  288. def eeg_clf_reset():
  289. """开始训练时要重置参数
  290. """
  291. global csp_clf
  292. global psd_clf
  293. global csp_dc
  294. global psd_dc
  295. csp_clf = CSPBasedClassifier()
  296. receiver = Receiver()
  297. psd_clf = PSDBasedClassifier(receiver.connector.sample_params.sample_rate)
  298. csp_dc = es.CSPDataCollector()
  299. psd_dc = es.PSDDataCollector(
  300. maxlen=int(settings.TRAIN_PARAMS['rest_stim_duration'] / 1000))
  301. @router.get("/set-train-finish-flag")
  302. def set_train_finish_flag(flag):
  303. global train_finish_flag
  304. train_finish_flag = flag
  305. return "设置成功"
  306. @router.get("/get-train-finish-flag")
  307. def get_train_finish_flag():
  308. return {"train_finish_flag": train_finish_flag}
  309. @router.get("/restart-fake-data")
  310. def restart_fake_data():
  311. receiver = Receiver()
  312. if receiver.connector.device == Device.FAKER:
  313. receiver.reset_wave()
  314. return {"status": 1}