"""
CSP unit tests
"""
# pylint: disable=missing-class-docstring
import os

import numpy as np

from core.sig_chain.sig_reader import Reader
from core.mi.eeg_csp import CspOffline
from core.mi.eeg_csp import CSPBasedClassifier

TEST_DATA_PATH = "tests/data/"
PONY_BDF_FILE_PATH = os.path.join(TEST_DATA_PATH, "eeg_raw_data.bdf")
NEO_BDF_FILE_PATH = os.path.join(TEST_DATA_PATH, "neo_eeg_raw_data.bdf")
PONY_CSP_FILE_PATH = os.path.join(TEST_DATA_PATH, "pony_csp.png")
NEO_CSP_FILE_PATH = os.path.join(TEST_DATA_PATH, "neo_csp.png")


def setup_module():
    if not os.path.exists(TEST_DATA_PATH):
        os.makedirs(TEST_DATA_PATH)


def teardown_module():
    # Remove the CSP images written by the offline tests below.
    if os.path.exists(PONY_CSP_FILE_PATH):
        os.remove(PONY_CSP_FILE_PATH)
    if os.path.exists(NEO_CSP_FILE_PATH):
        os.remove(NEO_CSP_FILE_PATH)


class TestCSPBasedClassifier:

    @classmethod
    def setup_class(cls):
        ch_names = [
            "T6", "P4", "Pz", "M2", "F8", "F4", "Fp1", "Cz", "M1", "F7", "F3",
            "C3", "T3", "A1", "Oz", "O1", "O2", "Fz", "C4", "T4", "Fp2", "A2",
            "T5", "P3"
        ]
        reader = Reader()
        cls.raw = reader.read(PONY_BDF_FILE_PATH, tuple(ch_names))
        # Merge successful and failed training trials into a single "mi" class.
        cls.raw.annotations.rename({
            "trainSuccess": "mi",
            "trainFailed": "mi",
            "restState": "rest"
        })
        cls.raw.annotations.duration += 0.999

    def setup_method(self):
        self.clf = CSPBasedClassifier()

    def test_train_and_test_with_same_sample_length(self):
        ch_names = ["C3", "Cz", "C4"]
        csp_offline = CspOffline()
        epochs, _ = csp_offline.get_epochs(self.raw, tuple(ch_names))
        labels = epochs.events[:, -1]
        # Train and predict on the same epochs.
        self.clf.fit(epochs.get_data(), labels)
        predicts = self.clf.predict(epochs.get_data())
        acc = np.sum(predicts == labels) / predicts.size
        assert self.clf.is_trained
        assert acc >= 0.8

    def test_train_and_test_with_different_sample_length(self):
        ch_names = ["C3", "Cz", "C4"]
        # Epochs cut with the default time window.
        csp_offline1 = CspOffline()
        epochs1, _ = csp_offline1.get_epochs(self.raw, tuple(ch_names))
        # Epochs cut with tmax = 3, so each epoch has a different sample count.
        csp_offline3 = CspOffline()
        csp_offline3.tmax = 3
        epochs3, _ = csp_offline3.get_epochs(self.raw, tuple(ch_names))
        # The labels of epochs1 are reused for epochs3, which assumes both
        # epoch sets were cut from the same events.
        labels = epochs1.events[:, -1]
        # Train on the tmax = 3 epochs, then predict on the default-window epochs.
        self.clf.fit(epochs3.get_data(), labels)
        predicts = self.clf.predict(epochs1.get_data())
        acc = np.sum(predicts == labels) / predicts.size
        assert self.clf.is_trained
        assert acc >= 0.8


def test_main_csp_offline_pony():
    ch_names = [
        "T6", "P4", "Pz", "F8", "F4", "Fp1", "Cz", "F7", "F3", "C3", "T3",
        "Oz", "O1", "O2", "Fz", "C4", "T4", "Fp2", "T5", "P3"
    ]
    reader = Reader()
    raw = reader.read(PONY_BDF_FILE_PATH, tuple(ch_names))
    raw.annotations.rename({
        "trainSuccess": "mi",
        "trainFailed": "mi",
        "restState": "rest"
    })
    csp_offline = CspOffline()
    epochs, _ = csp_offline.get_epochs(raw, tuple(ch_names))
    csp = csp_offline.process(epochs)
    csp_offline.draw_image(csp, epochs.info, save_path=PONY_CSP_FILE_PATH)


def test_main_csp_offline_neo():
    ch_names = ["C3", "FC3", "CP5", "CP1", "C4", "FC4", "CP2", "CP6"]
    reader = Reader()
    raw = reader.read(NEO_BDF_FILE_PATH, tuple(ch_names))
    raw.annotations.rename({
        "trainSuccess": "mi",
        "trainFailed": "mi",
        "restState": "rest"
    })
    csp_offline = CspOffline()
    epochs, _ = csp_offline.get_epochs(raw, tuple(ch_names))
    csp = csp_offline.process(epochs)
    csp_offline.draw_image(csp, epochs.info, save_path=NEO_CSP_FILE_PATH)