test_psd.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. """ core/mi/eeg_psd.py 单元测试 """
  2. # pylint: disable=missing-class-docstring
  3. import os
  4. import pytest
  5. import numpy as np
  6. from core.sig_chain.sig_reader import Reader
  7. from core.mi.eeg_psd import PSDBasedClassifier
  8. from core.mi.eeg_psd import Psd
  9. from tests.utils.core import get_epochs
  10. TEST_DATA_PATH = 'tests/data/'
  11. BDF_FILE_PATH = os.path.join(TEST_DATA_PATH, 'eeg_raw_data.bdf')
  12. PONY_PSD_FILE_PATH = os.path.join(TEST_DATA_PATH, 'pony_psd.png')
  13. def setup_module():
  14. if not os.path.exists(TEST_DATA_PATH):
  15. os.makedirs(TEST_DATA_PATH)
  16. def teardown_module():
  17. if os.path.exists(PONY_PSD_FILE_PATH):
  18. os.remove(PONY_PSD_FILE_PATH)
  19. class TestPSDBasedClassifier():
  20. @classmethod
  21. def setup_class(cls):
  22. ch_names = [
  23. 'T6', 'P4', 'Pz', 'M2', 'F8', 'F4', 'Fp1', 'Cz', 'M1', 'F7', 'F3',
  24. 'C3', 'T3', 'A1', 'Oz', 'O1', 'O2', 'Fz', 'C4', 'T4', 'Fp2', 'A2',
  25. 'T5', 'P3'
  26. ]
  27. reader = Reader()
  28. cls.raw = reader.read(BDF_FILE_PATH, tuple(ch_names))
  29. cls.raw.annotations.duration += 0.999
  30. def setup_method(self):
  31. self.clf = PSDBasedClassifier()
  32. def generate_one_sample(self, channel_count, high_freq=10, low_freq=0.4):
  33. # 生成信号:10Hz的正弦波 + 0.4Hz的正弦波
  34. tt = np.linspace(0, 1, 1000, endpoint=False)
  35. xx = np.sin(2 * np.pi * high_freq * tt) + np.sin(
  36. 2 * np.pi * low_freq * tt)
  37. return np.stack([xx for ch in range(channel_count)])
  38. def test_psd_feature_extract_return_correct_shape(self):
  39. sample = self.generate_one_sample(1)
  40. bp_sample = self.clf.psd_feature_extract(sample)
  41. assert isinstance(bp_sample, float)
  42. # assert (1,) == bp_sample.shape
  43. def test_psd_feature_extract_get_higher_value_for_matched_signal(self):
  44. sample_match = self.generate_one_sample(1)
  45. sample_not_match = self.generate_one_sample(1, high_freq=30)
  46. bp_match = self.clf.psd_feature_extract(sample_match)
  47. bp_not_match = self.clf.psd_feature_extract(sample_not_match)
  48. assert bp_match > bp_not_match
  49. def test_fit_with_single_channel_data(self):
  50. ch_names = ['C4']
  51. epochs = get_epochs(self.raw, tuple(ch_names), 'restState', tmax=0.999)
  52. train_success = self.clf.fit(epochs.get_data())
  53. assert train_success
  54. def test_fit_with_multi_channel_data(self):
  55. ch_names = ['C3', 'C4']
  56. epochs = get_epochs(self.raw, tuple(ch_names), 'restState', tmax=0.999)
  57. train_success = self.clf.fit(epochs.get_data())
  58. assert train_success
  59. def test_predict_before_fit_note_allowed(self):
  60. sample = self.generate_one_sample(1)
  61. with pytest.raises(Exception):
  62. self.clf.predict(sample[np.newaxis, :])
  63. def test_predict_with_single_channel_data(self):
  64. channel_count = 1
  65. self.clf.is_trained = True
  66. sample = self.generate_one_sample(channel_count)
  67. pred = self.clf.predict(sample[np.newaxis, :])
  68. assert pred in [0, 1]
  69. def test_predict_with_multi_channel_data(self):
  70. channel_count = 2
  71. self.clf.is_trained = True
  72. sample = self.generate_one_sample(channel_count)
  73. pred = self.clf.predict(sample[np.newaxis, :])
  74. assert pred in [0, 1]
  75. def test_main(self):
  76. ch_names = ['C4']
  77. epochs = get_epochs(self.raw, tuple(ch_names), 'restState', tmax=0.999)
  78. train_success = self.clf.fit(epochs.get_data())
  79. predicts = self.clf.predict(epochs.get_data())
  80. acc = np.sum(predicts == 0) / predicts.size
  81. assert train_success
  82. assert acc >= self.clf.acc_accepted
  83. epochs_mi = get_epochs(self.raw,
  84. tuple(ch_names),
  85. 'trainSuccess',
  86. tmax=0.999)
  87. predicts = self.clf.predict(epochs_mi.get_data())
  88. acc = np.sum(predicts == 1) / predicts.size
  89. assert acc >= self.clf.acc_accepted
  90. def test_pony():
  91. bdf_file_path = os.path.join(TEST_DATA_PATH, '5_3_right_hand.bdf')
  92. ch_names = ['C3', 'C4']
  93. reader = Reader()
  94. raw = reader.read(bdf_file_path, tuple(ch_names))
  95. reader.fix_annotation(raw)
  96. psd = Psd(0.1, 40, 0, 3)
  97. epochs = psd.get_epochs(raw, tuple(ch_names))
  98. psd.draw_image(epochs, ch_names, PONY_PSD_FILE_PATH)