123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- """ core/mi/eeg_psd.py 单元测试 """
- # pylint: disable=missing-class-docstring
- import os
- import pytest
- import numpy as np
- from core.sig_chain.sig_reader import Reader
- from core.mi.eeg_psd import PSDBasedClassifier
- from core.mi.eeg_psd import Psd
- from tests.utils.core import get_epochs
# Shared test-data locations. These are relative paths, so they resolve
# against the current working directory when the tests run.
TEST_DATA_PATH = 'tests/data/'
BDF_FILE_PATH, PONY_PSD_FILE_PATH = (
    os.path.join(TEST_DATA_PATH, file_name)
    for file_name in ('eeg_raw_data.bdf', 'pony_psd.png')
)
def setup_module():
    """Ensure the test-data directory exists before the module's tests run.

    ``exist_ok=True`` makes this a no-op when the directory is already
    present, and avoids the check-then-create race of testing
    ``os.path.exists`` separately before ``os.makedirs``.
    """
    os.makedirs(TEST_DATA_PATH, exist_ok=True)
def teardown_module():
    """Delete the PSD image that ``test_pony`` writes, once the module is done.

    Does nothing when no file exists at ``PONY_PSD_FILE_PATH``.
    """
    if not os.path.exists(PONY_PSD_FILE_PATH):
        return
    os.remove(PONY_PSD_FILE_PATH)
class TestPSDBasedClassifier:
    """Unit tests for ``core.mi.eeg_psd.PSDBasedClassifier``."""

    # End time (``tmax``) passed to ``get_epochs`` for every epoch these
    # tests extract. setup_class also lengthens each annotation's duration
    # by the same amount, presumably so annotations span the whole epoch
    # window -- TODO confirm against get_epochs.
    EPOCH_TMAX = 0.999

    @classmethod
    def setup_class(cls):
        """Load the shared BDF recording once for every test in the class."""
        ch_names = (
            'T6', 'P4', 'Pz', 'M2', 'F8', 'F4', 'Fp1', 'Cz', 'M1', 'F7', 'F3',
            'C3', 'T3', 'A1', 'Oz', 'O1', 'O2', 'Fz', 'C4', 'T4', 'Fp2', 'A2',
            'T5', 'P3',
        )
        reader = Reader()
        cls.raw = reader.read(BDF_FILE_PATH, ch_names)
        cls.raw.annotations.duration += cls.EPOCH_TMAX

    def setup_method(self):
        """Create a fresh classifier so no test sees another test's state."""
        self.clf = PSDBasedClassifier()

    def _get_epochs(self, ch_names, event):
        """Cut epochs for ``event`` on ``ch_names`` from the shared recording."""
        return get_epochs(self.raw, tuple(ch_names), event,
                          tmax=self.EPOCH_TMAX)

    def generate_one_sample(self, channel_count, high_freq=10, low_freq=0.4):
        """Return a synthetic multi-channel sample of two summed sinusoids.

        The signal ``sin(2*pi*high_freq*t) + sin(2*pi*low_freq*t)`` is
        evaluated at 1000 evenly spaced points of ``t`` in [0, 1). With the
        defaults this is a 10 Hz plus a 0.4 Hz component. Every channel
        receives an identical copy.

        Args:
            channel_count: Number of channels (rows) to produce.
            high_freq: Frequency of the first sine component.
            low_freq: Frequency of the second sine component.

        Returns:
            ndarray of shape ``(channel_count, 1000)``.
        """
        tt = np.linspace(0, 1, 1000, endpoint=False)
        signal = (np.sin(2 * np.pi * high_freq * tt)
                  + np.sin(2 * np.pi * low_freq * tt))
        # np.tile copies the 1-D waveform into each row. Unlike np.stack
        # over a list, it also handles channel_count == 0 (shape (0, 1000)).
        return np.tile(signal, (channel_count, 1))

    def test_psd_feature_extract_return_correct_shape(self):
        """Single-channel feature extraction yields one scalar ``float``."""
        sample = self.generate_one_sample(1)
        bp_sample = self.clf.psd_feature_extract(sample)
        assert isinstance(bp_sample, float)

    def test_psd_feature_extract_get_higher_value_for_matched_signal(self):
        """The default 10 Hz sample must score above the same with 30 Hz."""
        sample_match = self.generate_one_sample(1)
        sample_not_match = self.generate_one_sample(1, high_freq=30)
        bp_match = self.clf.psd_feature_extract(sample_match)
        bp_not_match = self.clf.psd_feature_extract(sample_not_match)
        assert bp_match > bp_not_match

    def test_fit_with_single_channel_data(self):
        """fit() reports success on single-channel rest-state epochs."""
        epochs = self._get_epochs(['C4'], 'restState')
        train_success = self.clf.fit(epochs.get_data())
        assert train_success

    def test_fit_with_multi_channel_data(self):
        """fit() reports success on two-channel rest-state epochs."""
        epochs = self._get_epochs(['C3', 'C4'], 'restState')
        train_success = self.clf.fit(epochs.get_data())
        assert train_success

    # "note" in this name is a typo for "not"; it is kept so the test ID
    # stays stable for anything that selects tests by name.
    def test_predict_before_fit_note_allowed(self):
        """predict() on a classifier that was never fitted must raise."""
        sample = self.generate_one_sample(1)
        # NOTE(review): the concrete exception type is not visible from this
        # file; narrow ``Exception`` once the classifier's contract is known.
        with pytest.raises(Exception):
            self.clf.predict(sample[np.newaxis, :])

    def _predict_synthetic(self, channel_count):
        """Predict one synthetic sample with the trained flag forced on.

        Setting ``is_trained`` directly skips fit(). Callers only check
        that the result is a 0/1 label, not that it is accurate.
        """
        self.clf.is_trained = True
        sample = self.generate_one_sample(channel_count)
        # Add a leading axis so predict() receives a batch of one sample.
        return self.clf.predict(sample[np.newaxis, :])

    def test_predict_with_single_channel_data(self):
        """predict() on a 1-channel sample returns label 0 or 1."""
        pred = self._predict_synthetic(1)
        assert pred in [0, 1]

    def test_predict_with_multi_channel_data(self):
        """predict() on a 2-channel sample returns label 0 or 1."""
        pred = self._predict_synthetic(2)
        assert pred in [0, 1]

    def test_main(self):
        """End-to-end: fit on C4 rest epochs, then check accuracy per class.

        At least ``clf.acc_accepted`` of the restState epochs must be
        predicted as 0, and at least that fraction of the trainSuccess
        epochs as 1.
        """
        ch_names = ['C4']
        rest_epochs = self._get_epochs(ch_names, 'restState')
        # Assert that training succeeded before predicting, so a failed fit
        # is reported as such rather than as a prediction/accuracy failure.
        train_success = self.clf.fit(rest_epochs.get_data())
        assert train_success

        rest_predicts = self.clf.predict(rest_epochs.get_data())
        assert np.mean(rest_predicts == 0) >= self.clf.acc_accepted

        mi_epochs = self._get_epochs(ch_names, 'trainSuccess')
        mi_predicts = self.clf.predict(mi_epochs.get_data())
        assert np.mean(mi_predicts == 1) >= self.clf.acc_accepted
def test_pony():
    """Smoke-test ``Psd``: epoch a C3/C4 recording and render a PSD image.

    Reads ``5_3_right_hand.bdf`` from the test-data directory and applies
    the reader's annotation fix. It then draws the figure to
    ``PONY_PSD_FILE_PATH``, which ``teardown_module`` deletes afterwards.
    """
    bdf_file_path = os.path.join(TEST_DATA_PATH, '5_3_right_hand.bdf')
    ch_names = ['C3', 'C4']
    reader = Reader()
    raw = reader.read(bdf_file_path, tuple(ch_names))
    reader.fix_annotation(raw)
    # NOTE(review): the positional Psd arguments (0.1, 40, 0, 3) look like a
    # frequency range followed by a time window -- confirm against Psd.
    psd = Psd(0.1, 40, 0, 3)
    epochs = psd.get_epochs(raw, tuple(ch_names))
    psd.draw_image(epochs, ch_names, PONY_PSD_FILE_PATH)
    # Without this assertion the test would only check that nothing raised.
    assert os.path.exists(PONY_PSD_FILE_PATH)
|