123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- """ CSP unit tests """
- # pylint: disable=missing-class-docstring
- import os
- import numpy as np
- from core.sig_chain.sig_reader import Reader
- from core.mi.eeg_csp import CspOffline
- from core.mi.eeg_csp import CSPBasedClassifier
# Directory that holds the recorded fixtures read by this module and receives
# the images the tests write.
TEST_DATA_PATH = "tests/data/"

# "Pony" recording and the CSP image rendered from it.
PONY_BDF_FILE_PATH = os.path.join(TEST_DATA_PATH, "eeg_raw_data.bdf")
PONY_CSP_FILE_PATH = os.path.join(TEST_DATA_PATH, "pony_csp.png")

# "Neo" recording and the CSP image rendered from it.
NEO_BDF_FILE_PATH = os.path.join(TEST_DATA_PATH, "neo_eeg_raw_data.bdf")
NEO_CSP_FILE_PATH = os.path.join(TEST_DATA_PATH, "neo_csp.png")
def setup_module():
    """Ensure the test data directory exists before any test in this module runs."""
    # exist_ok=True replaces the separate existence check: it is a single
    # call and cannot race with another process creating the directory.
    os.makedirs(TEST_DATA_PATH, exist_ok=True)
def teardown_module():
    """Delete the CSP images written by the tests in this module, if present."""
    for image_path in (PONY_CSP_FILE_PATH, NEO_CSP_FILE_PATH):
        # EAFP: attempt the removal and ignore only the "already gone" case,
        # instead of checking for existence first and then removing.
        try:
            os.remove(image_path)
        except FileNotFoundError:
            pass
class TestCSPBasedClassifier:
    """Fit/predict checks for ``CSPBasedClassifier`` on the pony recording.

    The recording is read once per class (``setup_class``); every test gets a
    fresh, untrained classifier (``setup_method``).
    """

    # Channels passed to ``CspOffline.get_epochs`` by every test in this class.
    CLASSIFIER_CHANNELS = ("C3", "Cz", "C4")

    # Minimum fraction of correctly predicted labels required by each test.
    MIN_ACCURACY = 0.8

    @classmethod
    def setup_class(cls):
        """Read the pony recording and normalise its annotations once."""
        recording_channels = (
            "T6", "P4", "Pz", "M2", "F8", "F4", "Fp1", "Cz", "M1", "F7", "F3",
            "C3", "T3", "A1", "Oz", "O1", "O2", "Fz", "C4", "T4", "Fp2", "A2",
            "T5", "P3",
        )
        cls.raw = Reader().read(PONY_BDF_FILE_PATH, recording_channels)
        # Collapse both training outcomes into a single "mi" label and
        # shorten the rest label.
        cls.raw.annotations.rename({
            "trainSuccess": "mi",
            "trainFailed": "mi",
            "restState": "rest"
        })
        # Lengthen every annotation by 0.999.
        # NOTE(review): presumably seconds, so that epochs fit inside the
        # annotated span -- confirm against CspOffline.get_epochs.
        cls.raw.annotations.duration += 0.999

    def setup_method(self):
        """Create a fresh, untrained classifier for each test."""
        self.clf = CSPBasedClassifier()

    def test_train_and_test_with_same_sample_length(self):
        """Fit and predict on the same epochs; expect at least 80% accuracy."""
        epochs, _ = CspOffline().get_epochs(self.raw, self.CLASSIFIER_CHANNELS)
        labels = epochs.events[:, -1]

        self.clf.fit(epochs.get_data(), labels)
        predicts = self.clf.predict(epochs.get_data())

        assert self.clf.is_trained
        assert np.mean(predicts == labels) >= self.MIN_ACCURACY

    def test_train_and_test_with_different_sample_length(self):
        """Fit on ``tmax = 3`` epochs, predict on default-length epochs."""
        default_epochs, _ = CspOffline().get_epochs(
            self.raw, self.CLASSIFIER_CHANNELS
        )

        tmax3_offline = CspOffline()
        tmax3_offline.tmax = 3
        tmax3_epochs, _ = tmax3_offline.get_epochs(
            self.raw, self.CLASSIFIER_CHANNELS
        )

        # Take each label vector from the epochs it describes.  Previously the
        # classifier was fitted on the tmax=3 data using the labels of the
        # default-length epochs, which is only valid if both extractions keep
        # exactly the same events.
        train_labels = tmax3_epochs.events[:, -1]
        test_labels = default_epochs.events[:, -1]

        self.clf.fit(tmax3_epochs.get_data(), train_labels)
        predicts = self.clf.predict(default_epochs.get_data())

        assert self.clf.is_trained
        assert np.mean(predicts == test_labels) >= self.MIN_ACCURACY
def test_main_csp_offline_pony():
    """Smoke-test the offline CSP pipeline on the pony recording.

    Reads the pony BDF file, relabels its annotations, extracts epochs, runs
    ``CspOffline.process`` and passes the result to ``draw_image`` with
    ``PONY_CSP_FILE_PATH`` as the save path.  There are no explicit
    assertions: the test passes when no step raises.
    """
    channels = (
        "T6", "P4", "Pz", "F8", "F4", "Fp1", "Cz", "F7", "F3", "C3", "T3", "Oz",
        "O1", "O2", "Fz", "C4", "T4", "Fp2", "T5", "P3",
    )
    label_map = {
        "trainSuccess": "mi",
        "trainFailed": "mi",
        "restState": "rest"
    }

    raw = Reader().read(PONY_BDF_FILE_PATH, channels)
    raw.annotations.rename(label_map)

    pipeline = CspOffline()
    epochs, _ = pipeline.get_epochs(raw, channels)
    pipeline.draw_image(
        pipeline.process(epochs), epochs.info, save_path=PONY_CSP_FILE_PATH
    )
def test_main_csp_offline_neo():
    """Smoke-test the offline CSP pipeline on the neo recording.

    Reads the neo BDF file, relabels its annotations, extracts epochs, runs
    ``CspOffline.process`` and passes the result to ``draw_image`` with
    ``NEO_CSP_FILE_PATH`` as the save path.  There are no explicit
    assertions: the test passes when no step raises.
    """
    channels = ("C3", "FC3", "CP5", "CP1", "C4", "FC4", "CP2", "CP6")
    label_map = {
        "trainSuccess": "mi",
        "trainFailed": "mi",
        "restState": "rest"
    }

    raw = Reader().read(NEO_BDF_FILE_PATH, channels)
    raw.annotations.rename(label_map)

    pipeline = CspOffline()
    epochs, _ = pipeline.get_epochs(raw, channels)
    pipeline.draw_image(
        pipeline.process(epochs), epochs.info, save_path=NEO_CSP_FILE_PATH
    )
|