test_csp.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. """ CSP 单元测试 """
  2. # pylint: disable=missing-class-docstring
  3. import os
  4. import numpy as np
  5. from core.sig_chain.sig_reader import Reader
  6. from core.mi.eeg_csp import CspOffline
  7. from core.mi.eeg_csp import CSPBasedClassifier
  8. TEST_DATA_PATH = "tests/data/"
  9. PONY_BDF_FILE_PATH = os.path.join(TEST_DATA_PATH, "eeg_raw_data.bdf")
  10. NEO_BDF_FILE_PATH = os.path.join(TEST_DATA_PATH, "neo_eeg_raw_data.bdf")
  11. PONY_CSP_FILE_PATH = os.path.join(TEST_DATA_PATH, "pony_csp.png")
  12. NEO_CSP_FILE_PATH = os.path.join(TEST_DATA_PATH, "neo_csp.png")
  13. def setup_module():
  14. if not os.path.exists(TEST_DATA_PATH):
  15. os.makedirs(TEST_DATA_PATH)
  16. def teardown_module():
  17. if os.path.exists(PONY_CSP_FILE_PATH):
  18. os.remove(PONY_CSP_FILE_PATH)
  19. if os.path.exists(NEO_CSP_FILE_PATH):
  20. os.remove(NEO_CSP_FILE_PATH)
  21. class TestCSPBasedClassifier():
  22. @classmethod
  23. def setup_class(cls):
  24. ch_names = [
  25. "T6", "P4", "Pz", "M2", "F8", "F4", "Fp1", "Cz", "M1", "F7", "F3",
  26. "C3", "T3", "A1", "Oz", "O1", "O2", "Fz", "C4", "T4", "Fp2", "A2",
  27. "T5", "P3"
  28. ]
  29. reader = Reader()
  30. cls.raw = reader.read(PONY_BDF_FILE_PATH, tuple(ch_names))
  31. cls.raw.annotations.rename({
  32. "trainSuccess": "mi",
  33. "trainFailed": "mi",
  34. "restState": "rest"
  35. })
  36. cls.raw.annotations.duration += 0.999
  37. def setup_method(self):
  38. self.clf = CSPBasedClassifier()
  39. def test_train_and_test_with_same_sample_length(self):
  40. ch_names = ["C3", "Cz", "C4"]
  41. csp_offline = CspOffline()
  42. epochs, _ = csp_offline.get_epochs(self.raw, tuple(ch_names))
  43. labels = epochs.events[:, -1]
  44. self.clf.fit(epochs.get_data(), labels)
  45. predicts = self.clf.predict(epochs.get_data())
  46. acc = np.sum(predicts == labels) / predicts.size
  47. assert self.clf.is_trained
  48. assert acc >= 0.8
  49. def test_train_and_test_with_different_sample_length(self):
  50. ch_names = ["C3", "Cz", "C4"]
  51. csp_offline1 = CspOffline()
  52. epochs1, _ = csp_offline1.get_epochs(self.raw, tuple(ch_names))
  53. csp_offline3 = CspOffline()
  54. csp_offline3.tmax = 3
  55. epochs3, _ = csp_offline3.get_epochs(self.raw, tuple(ch_names))
  56. labels = epochs1.events[:, -1]
  57. self.clf.fit(epochs3.get_data(), labels)
  58. predicts = self.clf.predict(epochs1.get_data())
  59. acc = np.sum(predicts == labels) / predicts.size
  60. assert self.clf.is_trained
  61. assert acc >= 0.8
  62. def test_main_csp_offline_pony():
  63. ch_names = [
  64. "T6", "P4", "Pz", "F8", "F4", "Fp1", "Cz", "F7", "F3", "C3", "T3", "Oz",
  65. "O1", "O2", "Fz", "C4", "T4", "Fp2", "T5", "P3"
  66. ]
  67. reader = Reader()
  68. raw = reader.read(PONY_BDF_FILE_PATH, tuple(ch_names))
  69. raw.annotations.rename({
  70. "trainSuccess": "mi",
  71. "trainFailed": "mi",
  72. "restState": "rest"
  73. })
  74. csp_offline = CspOffline()
  75. epochs, _ = csp_offline.get_epochs(raw, tuple(ch_names))
  76. csp = csp_offline.process(epochs)
  77. csp_offline.draw_image(csp, epochs.info, save_path=PONY_CSP_FILE_PATH)
  78. def test_main_csp_offline_neo():
  79. ch_names = ["C3", "FC3", "CP5", "CP1", "C4", "FC4", "CP2", "CP6"]
  80. reader = Reader()
  81. raw = reader.read(NEO_BDF_FILE_PATH, tuple(ch_names))
  82. raw.annotations.rename({
  83. "trainSuccess": "mi",
  84. "trainFailed": "mi",
  85. "restState": "rest"
  86. })
  87. csp_offline = CspOffline()
  88. epochs, _ = csp_offline.get_epochs(raw, tuple(ch_names))
  89. csp = csp_offline.process(epochs)
  90. csp_offline.draw_image(csp, epochs.info, save_path=NEO_CSP_FILE_PATH)