test_erds.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. """ RED/ERS 单元测试 """
  2. import os
  3. from core.sig_chain.sig_reader import Reader
  4. from core.mi.eeg_erds import ErdErs
  5. TEST_DATA_PATH = "tests/data/"
  6. BDF_FILE_PATH = os.path.join(TEST_DATA_PATH, "eeg_raw_data.bdf")
  7. ERDS_FILE_PATH = os.path.join(TEST_DATA_PATH, "erds.png")
  8. TFR_ERDS_FILE_PATH = os.path.join(TEST_DATA_PATH, "tfr_erds.png")
  9. def setup_module():
  10. if not os.path.exists(TEST_DATA_PATH):
  11. os.makedirs(TEST_DATA_PATH)
  12. def teardown_module():
  13. if os.path.exists(ERDS_FILE_PATH):
  14. os.remove(ERDS_FILE_PATH)
  15. if os.path.exists(TFR_ERDS_FILE_PATH):
  16. os.remove(TFR_ERDS_FILE_PATH)
  17. def test_main():
  18. # ERD/ERS
  19. # 左右手
  20. # [ "C3", "C4" ]
  21. ch_names = ["C3", "Cz", "C4"]
  22. reader = Reader()
  23. raw = reader.read(BDF_FILE_PATH, ch_names)
  24. raw.annotations.rename({
  25. "trainSuccess": "mi",
  26. "trainFailed": "mi",
  27. "restState": "rest"
  28. })
  29. raw.resample(200)
  30. channels = ("C3", "C4")
  31. erds = ErdErs(-1, 1)
  32. epochs, event_id_pick = erds.get_epochs(raw, channels)
  33. tfr = erds.process(epochs, (-1, 0), mode="percent")
  34. erds.draw_image(tfr, channels, ERDS_FILE_PATH)
  35. tfr = erds.process(epochs, (-1, 0))
  36. erds.draw_tfr_image(tfr, event_id_pick, channels, TFR_ERDS_FILE_PATH)