"""Unit tests for core/mi/eeg_psd.py."""
# pylint: disable=missing-class-docstring
import contextlib
import os

import pytest
import numpy as np

from core.sig_chain.sig_reader import Reader
from core.mi.eeg_psd import PSDBasedClassifier, Psd
from tests.utils.core import get_epochs

TEST_DATA_PATH = 'tests/data/'
BDF_FILE_PATH = os.path.join(TEST_DATA_PATH, 'eeg_raw_data.bdf')
PONY_PSD_FILE_PATH = os.path.join(TEST_DATA_PATH, 'pony_psd.png')

# Value passed as ``tmax`` to ``get_epochs`` for every epoch extraction below.
EPOCH_TMAX = 0.999


def setup_module():
    """Ensure the test data directory exists before any test in this module runs."""
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(TEST_DATA_PATH, exist_ok=True)


def teardown_module():
    """Delete the PSD image produced by ``test_pony``, if one was written."""
    # Missing file is fine (e.g. test_pony failed or was deselected).
    with contextlib.suppress(FileNotFoundError):
        os.remove(PONY_PSD_FILE_PATH)


class TestPSDBasedClassifier:
    """Tests for ``PSDBasedClassifier`` on synthetic signals and a recorded BDF file."""

    @classmethod
    def setup_class(cls):
        """Read the shared raw recording once for all tests in this class."""
        ch_names = [
            'T6', 'P4', 'Pz', 'M2', 'F8', 'F4', 'Fp1', 'Cz', 'M1', 'F7', 'F3',
            'C3', 'T3', 'A1', 'Oz', 'O1', 'O2', 'Fz', 'C4', 'T4', 'Fp2', 'A2',
            'T5', 'P3'
        ]
        reader = Reader()
        cls.raw = reader.read(BDF_FILE_PATH, tuple(ch_names))
        # Lengthen every annotation by EPOCH_TMAX.
        # NOTE(review): presumably so each annotation spans the epoch window
        # extracted with tmax=EPOCH_TMAX; confirm the two are meant to match.
        cls.raw.annotations.duration += EPOCH_TMAX

    def setup_method(self):
        """Give every test a fresh, untrained classifier."""
        self.clf = PSDBasedClassifier()

    def _get_epochs(self, ch_names, event):
        """Return epochs for ``event`` restricted to ``ch_names`` from ``self.raw``."""
        return get_epochs(self.raw, tuple(ch_names), event, tmax=EPOCH_TMAX)

    def generate_one_sample(self, channel_count, high_freq=10, low_freq=0.4):
        """Build a synthetic 1-second sample whose channels are all identical.

        Each row is ``sin(2*pi*high_freq*t) + sin(2*pi*low_freq*t)`` evaluated
        at 1000 evenly spaced points over [0, 1) s, i.e. a 1000 Hz sampling rate.

        Args:
            channel_count: number of rows (channels) in the returned array.
            high_freq: frequency in Hz of the first sine component.
            low_freq: frequency in Hz of the second sine component.

        Returns:
            ndarray of shape ``(channel_count, 1000)``.
        """
        tt = np.linspace(0, 1, 1000, endpoint=False)
        xx = np.sin(2 * np.pi * high_freq * tt) + np.sin(2 * np.pi * low_freq * tt)
        # Repeat the 1-D signal as channel_count identical rows.
        return np.tile(xx, (channel_count, 1))

    def test_psd_feature_extract_return_correct_shape(self):
        """A single-channel sample is reduced to one float feature."""
        sample = self.generate_one_sample(1)
        bp_sample = self.clf.psd_feature_extract(sample)
        assert isinstance(bp_sample, float)

    def test_psd_feature_extract_get_higher_value_for_matched_signal(self):
        """A 10 Hz component gives a larger feature than a 30 Hz component."""
        sample_match = self.generate_one_sample(1)
        sample_not_match = self.generate_one_sample(1, high_freq=30)
        bp_match = self.clf.psd_feature_extract(sample_match)
        bp_not_match = self.clf.psd_feature_extract(sample_not_match)
        assert bp_match > bp_not_match

    def test_fit_with_single_channel_data(self):
        """``fit`` reports success on single-channel rest-state epochs."""
        epochs = self._get_epochs(['C4'], 'restState')
        train_success = self.clf.fit(epochs.get_data())
        assert train_success

    def test_fit_with_multi_channel_data(self):
        """``fit`` reports success on two-channel rest-state epochs."""
        epochs = self._get_epochs(['C3', 'C4'], 'restState')
        train_success = self.clf.fit(epochs.get_data())
        assert train_success

    def test_predict_before_fit_note_allowed(self):
        """``predict`` must raise when the classifier has not been fitted."""
        sample = self.generate_one_sample(1)
        # NOTE(review): narrow Exception to the specific type PSDBasedClassifier
        # raises for an untrained model once that is confirmed.
        with pytest.raises(Exception):
            self.clf.predict(sample[np.newaxis, :])

    def test_predict_with_single_channel_data(self):
        """``predict`` on one single-channel sample yields label 0 or 1."""
        # Set the trained flag directly so predict() can run without fit().
        self.clf.is_trained = True
        sample = self.generate_one_sample(1)
        pred = self.clf.predict(sample[np.newaxis, :])
        assert pred in [0, 1]

    def test_predict_with_multi_channel_data(self):
        """``predict`` on one two-channel sample yields label 0 or 1."""
        # Set the trained flag directly so predict() can run without fit().
        self.clf.is_trained = True
        sample = self.generate_one_sample(2)
        pred = self.clf.predict(sample[np.newaxis, :])
        assert pred in [0, 1]

    def test_main(self):
        """Fit on rest epochs, then check accuracy on rest (0) and trainSuccess (1) epochs."""
        ch_names = ['C4']
        epochs = self._get_epochs(ch_names, 'restState')
        train_success = self.clf.fit(epochs.get_data())
        assert train_success

        # Rest-state epochs are expected to be classified as 0.
        predicts = self.clf.predict(epochs.get_data())
        acc = np.sum(predicts == 0) / predicts.size
        assert acc >= self.clf.acc_accepted

        # 'trainSuccess' epochs are expected to be classified as 1.
        epochs_mi = self._get_epochs(ch_names, 'trainSuccess')
        predicts = self.clf.predict(epochs_mi.get_data())
        acc = np.sum(predicts == 1) / predicts.size
        assert acc >= self.clf.acc_accepted


def test_pony():
    """Smoke test: draw the PSD image for the '5_3_right_hand.bdf' recording."""
    bdf_file_path = os.path.join(TEST_DATA_PATH, '5_3_right_hand.bdf')
    ch_names = ['C3', 'C4']
    reader = Reader()
    raw = reader.read(bdf_file_path, tuple(ch_names))
    reader.fix_annotation(raw)
    # NOTE(review): positional Psd arguments; see Psd's signature for their meaning.
    psd = Psd(0.1, 40, 0, 3)
    epochs = psd.get_epochs(raw, tuple(ch_names))
    # Presumably writes the image to PONY_PSD_FILE_PATH (teardown_module removes it).
    psd.draw_image(epochs, ch_names, PONY_PSD_FILE_PATH)