Browse Source

Feat: 将NEO项目内容全部移入kraken

dk 1 year ago
parent
commit
05deaea839

+ 3 - 0
.gitignore

@@ -177,3 +177,6 @@ model/*
 
 backend/data/*
 !backend/data/.gitkeep
+
+backend/tests/data/*
+!backend/tests/data/.gitkeep

+ 0 - 0
backend/dataloaders/__init__.py


+ 60 - 0
backend/dataloaders/library_ieeg.py

@@ -0,0 +1,60 @@
+from scipy import io as sio
+import mne
+import numpy as np
+from .utils import upsample_events
+
+
+# loader for test data
def raw_preprocessing(data_file, finger_model='cylinder', epoch_time=1., fs=1000):
    """Load a ``.mat`` recording and return it as an annotated ``mne.io.RawArray``.

    Args:
        data_file: path of the MATLAB file; it must provide ``data`` and ``stim``.
        finger_model: annotation name given to movement events (event code 1).
        epoch_time: length in seconds of every sub-event after upsampling.
        fs: sampling rate in Hz.

    Returns:
        The notch-filtered recording with ``finger_model`` / ``'rest'`` annotations.
    """
    mat = sio.loadmat(data_file, simplify_cells=True, squeeze_me=True)

    # Transposed so that rows are channels, then scaled to volts.
    signal = mat['data'].astype(np.float64).T * 0.0298 * 1e-6
    stim = mat['stim'].astype(np.float64)

    # Remove 60 Hz line noise and its next two harmonics.
    signal = mne.filter.notch_filter(signal, fs, [60, 120, 180], trans_bandwidth=3, verbose=False)

    # Trial-level events, each split into consecutive epoch_time-long sub-events.
    samples_per_epoch = int(epoch_time * fs)
    epoch_events = upsample_events(extract_events(stim, fs), samples_per_epoch)

    channel_names = [f'ch_{i}' for i in range(signal.shape[0])]
    info = mne.create_info(channel_names, sfreq=fs, ch_types='ecog')
    raw = mne.io.RawArray(signal, info)

    label_names = {1: finger_model, 0: 'rest'}
    raw.set_annotations(mne.annotations_from_events(epoch_events, fs, label_names))
    return raw
+
+
def extract_events(stim_events, fs=1000.):
    """Turn a stimulus channel into alternating movement / rest events.

    A step of +12 between consecutive samples marks a movement cue onset and a
    step of -12 marks its offset. Movement onsets and rest onsets are both
    delayed by 500 ms to compensate for reaction time.

    Args:
        stim_events: 1-D array with the stimulus code of every sample.
        fs: sampling rate in Hz.

    Returns:
        ``(2 * n_trials, 3)`` int64 array of ``[onset, duration, label]`` rows,
        movement (label 1) and rest (label 0) interleaved. Every row carries the
        same duration, the common movement-trial length in samples.

    Raises:
        ValueError: if no complete trial is present, if onsets and offsets
            cannot be paired, or if the trials differ in length.
    """
    diff_stim = np.diff(stim_events)

    shift_idx = int(0.5 * fs)  # shift by 500 ms, compensate for reaction time

    # hand only
    onsets = np.flatnonzero(diff_stim == 12) + shift_idx
    offsets = np.flatnonzero(diff_stim == -12)

    # Without at least one onset and one offset there is nothing to pair;
    # fail with a clear message instead of an IndexError further down.
    if len(onsets) == 0 or len(offsets) == 0:
        raise ValueError('No complete trial found in stim channel')

    # handle a recording that starts or ends in the middle of a trial
    if len(onsets) != len(offsets):
        if offsets[0] <= onsets[0]:
            # first offset has no matching onset: cut first trial
            offsets = offsets[1:]
        else:
            # last onset has no matching offset: cut last trial
            onsets = onsets[:-1]

    if len(onsets) == 0 or len(onsets) != len(offsets):
        raise ValueError('Onsets and offsets in stim channel cannot be paired')

    rest_onset = offsets + shift_idx
    trial_lengths = offsets - onsets
    if len(np.unique(trial_lengths)) > 1:
        raise ValueError('Unequal trial length?')
    trial_length = trial_lengths[0]

    # build events: even rows are movement, odd rows are the following rest
    events = np.zeros((len(onsets) * 2, 3), dtype=np.int64)
    events[::2, 0] = onsets
    events[:, 1] = trial_length
    events[1::2, 0] = rest_onset
    events[::2, 2] = 1
    return events

+ 168 - 0
backend/dataloaders/neo.py

@@ -0,0 +1,168 @@
+import numpy as np
+import os
+import json
+import mne
+import glob
+import pyedflib
+from .utils import upsample_events
+
+
+FINGERMODEL_IDS = {
+    'rest': 0,
+    'cylinder': 1,
+    'ball': 2,
+    'flex': 3,
+    'double': 4,
+    'treble': 5
+}
+
+
def raw_preprocessing(data_root, session_paths:dict, epoch_time=1., epoch_length=5, rename_event=True):
    """Load sessions, rebuild their trial events and return one filtered Raw.

    Params:
        data_root: directory that the session paths are relative to
        session_paths: dict mapping a finger-model name to a list of session paths
        epoch_time: length in seconds of every sub-event after upsampling
        epoch_length (number or 'varied'): trial duration in seconds, or
            'varied' to take each duration from the gap to the next trial
        rename_event (True, use unified event label, False use original)
    Returns:
        The concatenated recording, annotated with finger-model / 'rest'
        events and notch-filtered at 50, 100 and 150 Hz.
    Raises:
        ValueError: if epoch_length is neither a number nor 'varied'.
    """
    # Validate up front so a bad argument fails before any file is read.
    if isinstance(epoch_length, (int, float, np.integer, np.floating)):
        trial_duration = epoch_length
    elif isinstance(epoch_length, str) and epoch_length == 'varied':
        trial_duration = None
    else:
        raise ValueError(f'Unsupported epoch_length {epoch_length}')

    raws_loaded = load_sessions(data_root, session_paths)
    # process event
    raws = []
    for (finger_model, raw) in raws_loaded:
        fs = raw.info['sfreq']
        events, _ = mne.events_from_annotations(raw)

        if rename_event:
            # codes 2 and 3 are movement trials, code 4 is rest
            mov_trial_ind = [2, 3]
            rest_trial_ind = [4]
        else:
            # Original labels are integer event codes, so look the finger
            # model up by its id; comparing against the name never matches.
            mov_trial_ind = [FINGERMODEL_IDS[finger_model]]
            rest_trial_ind = [FINGERMODEL_IDS['rest']]

        events = reconstruct_events(events, fs, finger_model,
                                    mov_trial_ind=mov_trial_ind,
                                    rest_trial_ind=rest_trial_ind,
                                    trial_duration=trial_duration,
                                    use_original_label=not rename_event)

        events_upsampled = upsample_events(events, int(fs * epoch_time))
        annotations = mne.annotations_from_events(events_upsampled, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
        raw.set_annotations(annotations)
        raws.append(raw)

    raws = mne.concatenate_raws(raws)
    raws.load_data()

    # filter 50Hz and its next two harmonics
    raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)

    return raws
+
+
def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind=(2, 3), rest_trial_ind=(4,), use_original_label=False):
    """Collapse a raw event sequence into one event per movement / rest trial.

    Consecutive events belonging to the same group (movement or rest) are
    merged: only the first event of each run is kept. Events whose code is in
    neither group are ignored.

    Args:
        events: ``(n, 3)`` integer array of ``[onset, duration, code]`` rows.
        fs: sampling rate in Hz.
        finger_model: key of ``FINGERMODEL_IDS`` used to relabel movement
            trials; ``None`` keeps the original codes.
        trial_duration: trial length in seconds, or ``None`` to use the gap to
            the next kept event (the last trial ends at the last raw event).
        mov_trial_ind: event codes that count as movement.
        rest_trial_ind: event codes that count as rest.
        use_original_label: keep the original codes instead of relabelling.

    Returns:
        ``(n_trials, 3)`` array of ``[onset, duration_in_samples, label]`` rows;
        it has zero rows when no event matches either group.
    """
    # Default codes: initialRest: 1, miFailed & miSuccess: 2 & 3, rest: 4
    # (initialRest is ignored).
    events = np.asarray(events)
    mov_trial_ind = list(mov_trial_ind)
    rest_trial_ind = list(rest_trial_ind)
    codes = events[:, 2]

    # A rising edge (0 -> 1) of the membership mask marks the start of a run.
    mov_start = np.diff(np.isin(codes, mov_trial_ind).astype(np.int8), prepend=0) == 1
    rest_start = np.diff(np.isin(codes, rest_trial_ind).astype(np.int8), prepend=0) == 1
    trial_rows = np.flatnonzero(np.logical_or(mov_start, rest_start))
    events_new = events[trial_rows]  # fancy indexing copies, input is untouched

    if len(events_new) == 0:
        # No trial found: return the empty (0, 3) array rather than indexing
        # its non-existent last row below.
        return events_new

    if trial_duration is None:
        events_new[:-1, 1] = np.diff(events_new[:, 0])
        events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
    else:
        events_new[:, 1] = int(trial_duration * fs)

    events_final = events_new.copy()
    if not use_original_label and finger_model is not None:
        # Both masks are taken from events_new so that writing the movement
        # label cannot influence which rows are treated as rest.
        ind_mov = np.flatnonzero(np.isin(events_new[:, 2], mov_trial_ind))
        events_final[ind_mov, 2] = FINGERMODEL_IDS[finger_model]
        ind_rest = np.flatnonzero(np.isin(events_new[:, 2], rest_trial_ind))
        events_final[ind_rest, 2] = 0
    return events_final
+
+
def load_sessions(data_root, session_names: dict):
    """Load every listed session, interleaving the finger models.

    Sessions are taken round by round: the first session of every finger model
    (in dict order), then the second of every model, and so on. Models that
    have run out of sessions are skipped.

    Args:
        data_root: directory that the session paths are relative to.
        session_names: dict mapping a finger-model name to a list of session
            paths. It is not modified.

    Returns:
        List of ``(finger_model, raw)`` tuples in interleaved order.

    Raises:
        FileNotFoundError: if a session folder contains no ``.bdf`` file.
    """
    # Work on snapshots so the caller's lists are left untouched.
    per_model = {model: list(paths) for model, paths in session_names.items()}
    n_rounds = max((len(paths) for paths in per_model.values()), default=0)

    raws = []
    for round_idx in range(n_rounds):
        for finger_model, paths in per_model.items():
            if round_idx >= len(paths):
                continue
            session_dir = os.path.join(data_root, paths[round_idx])
            if glob.glob(os.path.join(session_dir, 'evt.bdf')):
                # neo format
                raw = load_neuracle(session_dir)
            else:
                # kraken format
                data_files = glob.glob(os.path.join(session_dir, '*.bdf'))
                if not data_files:
                    raise FileNotFoundError(f'No .bdf file found in {session_dir}')
                raw = mne.io.read_raw_bdf(data_files[0])
            raws.append((finger_model, raw))
    return raws
+
+
def load_neuracle(data_dir, data_type='ecog'):
    """
    neuracle file loader

    Reads ``data.bdf``, ``evt.bdf`` and ``recordInformation.json`` from one
    recording folder.

    :param
        data_dir: root data dir for the experiment
        data_type: channel type assigned to every channel
    :return:
        raw: mne.io.RawArray; annotated with the recorded events when the
            event file can be opened, otherwise returned without annotations
    """
    f = {
        'data': os.path.join(data_dir, 'data.bdf'),
        'evt': os.path.join(data_dir, 'evt.bdf'),
        'info': os.path.join(data_dir, 'recordInformation.json')
    }
    # read json
    with open(f['info'], 'r') as json_file:
        record_info = json.load(json_file)
    start_time_point = record_info['DataFileInformations'][0]['BeginTimeStamp']
    sfreq = record_info['SampleRate']

    # read data; always release the file handle, even if reading fails
    f_data = pyedflib.EdfReader(f['data'])
    try:
        ch_names = f_data.getSignalLabels()
        # scaled by 1e-6, presumably microvolts to volts - TODO confirm
        data = np.array([f_data.readSignal(i) for i in range(f_data.signals_in_file)]) * 1e-6
    finally:
        f_data.close()

    info = mne.create_info(ch_names, sfreq, [data_type] * len(ch_names))
    raw = mne.io.RawArray(data, info)

    # read event; an event file that cannot be opened or read is tolerated
    # and the recording is returned without annotations
    try:
        f_evt = pyedflib.EdfReader(f['evt'])
        try:
            onset, duration, content = f_evt.readAnnotations()
        finally:
            f_evt.close()
    except OSError:
        return raw

    onset = np.array(onset) - start_time_point * 1e-3  # correct by start time point
    onset = (onset * sfreq).astype(np.int64)
    duration = (np.array(duration) * sfreq).astype(np.int64)

    # every distinct annotation text gets an integer id, in sorted order
    event_mapping = {c: i for i, c in enumerate(np.unique(content))}
    event_ids = [event_mapping[i] for i in content]
    events = np.stack((onset, duration, event_ids), axis=1)

    annotations = mne.annotations_from_events(events, sfreq)
    raw.set_annotations(annotations)

    return raw

+ 40 - 0
backend/dataloaders/utils.py

@@ -0,0 +1,40 @@
+import numpy as np
+import mne
+
+
def upsample_events(events, upsample_interval=500):
    """Split every event into consecutive sub-events of fixed length.

    For an event ``[onset, duration, ..., label]`` one sub-event is emitted at
    ``onset``, ``onset + upsample_interval``, ... for as long as a whole
    interval still fits inside ``duration``. Events shorter than one interval
    produce nothing.

    Args:
        events: iterable of rows whose first item is the onset, second item the
            duration in samples and last item the label.
        upsample_interval: spacing between sub-events, in samples.

    Returns:
        ``(n, 3)`` array of ``[onset, 0, label]`` rows. When no sub-event fits,
        an empty int64 array of shape ``(0, 3)`` is returned so that callers can
        still index columns.
    """
    events_new = [
        [e_[0] + i, 0, e_[-1]]
        for e_ in events
        for i in range(0, e_[1] - upsample_interval + 1, upsample_interval)
    ]
    if not events_new:
        # np.array([]) would have shape (0,), which breaks events[:, k].
        return np.empty((0, 3), dtype=np.int64)
    return np.array(events_new)
+
+
def extend_signal(raw, frequencies, freq_band):
    """Build a filter-bank version of ``raw``.

    For every centre frequency ``f`` the recording is band-passed to
    ``[f - freq_band, f + freq_band]``; the filtered copies are stacked along
    the channel axis and named ``<channel>-<f>Hz``.

    Returns:
        ``mne.io.RawArray`` with ``len(raw.ch_names) * len(frequencies)``
        channels of type ``'ecog'``.
    """
    filtered_bands = []
    for f in frequencies:
        filtered_bands.append(
            bandpass_filter(raw, l_freq=f - freq_band, h_freq=f + freq_band)
        )
    raw_ext = np.vstack(filtered_bands)

    # Same ordering as the stacked data: all channels of the first band first.
    ext_names = [ch + '-' + str(f) + 'Hz'
                 for f in frequencies
                 for ch in raw.ch_names]
    ext_types = ['ecog'] * len(raw.ch_names) * len(frequencies)

    info = mne.create_info(
        ch_names=ext_names,
        ch_types=ext_types,
        sfreq=int(raw.info['sfreq'])
    )

    return mne.io.RawArray(raw_ext, info)
+
+
def bandpass_filter(raw, l_freq, h_freq, method="iir", verbose=False):
    """Return the data of a band-pass filtered copy of ``raw``.

    The input object is copied first, so it is not filtered in place.
    """
    filter_kwargs = dict(
        l_freq=l_freq,
        h_freq=h_freq,
        method=method,
        verbose=verbose,
    )
    filtered = raw.copy().filter(**filter_kwargs)
    return filtered.get_data()

+ 0 - 163
backend/device/sig_chain/device/faker.py

@@ -1,163 +0,0 @@
-"""接收假数据
-
-Typical usage example:
-
-    connector = FakerConnector()
-    if connector.get_ready():
-        for _ in range(20):
-            connector.receive_wave()
-    connector.stop()
-"""
-import logging
-import numpy as np
-import socket
-
-from device.sig_chain.device.connector_interface import Connector
-from device.sig_chain.device.connector_interface import DataBlockInBuf
-from device.sig_chain.device.connector_interface import Device
-from device.sig_chain.utils import Observable
-from device.sig_chain.utils import Singleton
-
-logger = logging.getLogger(__name__)
-
-
-class SampleParams:
-
-    def __init__(self, channel_count, sample_rate, delay_milliseconds):
-        self.channel_count = channel_count
-        self.channel_labels = [
-            'T6', 'P4', 'Pz', 'M2', 'F8', 'F4', 'Fp1', 'Cz', 'M1', 'F7', 'F3',
-            'C3', 'T3', 'A1', 'Oz', 'O1', 'O2', 'Fz', 'C4', 'T4', 'Fp2', 'A2',
-            'T5', 'P3'
-        ][:self.channel_count]
-        # montage 中定义的通道类型
-        self.channel_types = (['eeg'] * 24)[:self.channel_count]
-        self.sample_rate = sample_rate
-        self.delay_milliseconds = delay_milliseconds
-        self.point_size = 4
-        self.timestamp_size = 8
-        self.data_count_per_channel = int(self.delay_milliseconds *
-                                          self.sample_rate / 1000)
-        self.data_block_size = self.channel_count * self.data_count_per_channel
-        self.buffer_size = self.timestamp_size + self.data_block_size * self.point_size
-        self.physical_max = 20000
-        self.physical_min = -20000
-
-    def refresh(self):
-        self.data_count_per_channel = int(self.delay_milliseconds *
-                                          self.sample_rate / 1000)
-        self.data_block_size = self.channel_count * self.data_count_per_channel
-        self.buffer_size = self.timestamp_size + self.data_block_size * self.point_size
-
-
-class FakerConnector(Connector, Singleton, Observable):
-
-    def __init__(self) -> None:
-        Observable.__init__(self)
-        self.device = Device.FAKER
-        self._host = '127.0.0.1'
-        self._port = 21112
-        self._addr = (self._host, self._port)
-        self._sock = None
-        self._timestamp = 0
-
-        self.sample_params = SampleParams(24, 1000, 250)
-
-        self._is_connected = False
-
-        self.buffer_save = None
-        self.saver = None
-
-    def load_config(self, config_info):
-        if config_info.get('host'):
-            self._host = config_info['host']
-            logger.info('Set host to: %s', self._host)
-        if config_info.get('port'):
-            self._port = config_info['port']
-            logger.info('Set port to: %s', self._port)
-        if config_info.get('channel_count'):
-            self.sample_params.channel_count = config_info['channel_count']
-            logger.info('Set channel count to: %s',
-                        self.sample_params.channel_count)
-        if config_info.get('channel_labels'):
-            assert len( config_info['channel_labels']) == \
-                self.sample_params.channel_count, \
-                'Mismatch of channel labels and channel count'
-            self.sample_params.channel_labels = config_info['channel_labels']
-            logger.info('Set channel labels to: %s',
-                        self.sample_params.channel_labels)
-        if config_info.get('sample_rate'):
-            self.sample_params.sample_rate = config_info['sample_rate']
-            logger.info('Set sample rate to: %s',
-                        self.sample_params.sample_rate)
-        if config_info.get('delay_milliseconds'):
-            self.sample_params.delay_milliseconds = config_info[
-                'delay_milliseconds']
-            logger.info('Set delay milliseconds to: %s',
-                        self.sample_params.delay_milliseconds)
-        # NOTICE: 放在最后执行,以确保更改对buffer生效
-        self._addr = (self._host, self._port)
-        self.sample_params.refresh()
-
-    def is_connected(self):
-        return self._is_connected
-
-    def get_ready(self):
-        self._sock = socket.socket()
-        try:
-            self._sock.connect(self._addr)
-            self._is_connected = True
-            self._sock.sendall(bytes('start', encoding='utf-8'))
-        except ConnectionRefusedError:
-            return False
-        return True
-
-    def setup_wave_mode(self):
-        return True
-
-    def setup_impedance_mode(self):
-        return False
-
-    def receive_wave(self):
-        try:
-            packet = self._sock.recv(self.sample_params.buffer_size)
-            # timestamp = struct.unpack_from("d", packet[:2])
-            packet_parse = np.frombuffer(packet, dtype=np.float32)
-            data_block = packet_parse[2:].reshape(
-                self.sample_params.channel_count,
-                self.sample_params.data_count_per_channel)
-            self._add_a_data_block_to_buffer(data_block)
-            return True
-        except ConnectionAbortedError:
-            return False
-        except IOError:
-            return False
-
-    def receive_impedance(self):
-        raise NotImplementedError
-
-    def _add_a_data_block_to_buffer(self, data_block: np.ndarray):
-        self._timestamp += int(1000 *
-                               self.sample_params.data_count_per_channel /
-                               self.sample_params.sample_rate)
-        data_block_in_buffer = DataBlockInBuf(data_block, self._timestamp)
-        self._save_data_when_buffer_full(data_block_in_buffer)
-        self.notify_observers(data_block_in_buffer)
-
-        return data_block
-
-    def stop(self):
-        if self._sock:
-            self._sock.close()
-        self._is_connected = False
-        self._timestamp = 0
-
-        if self.saver and self.saver.is_ready:
-            self.saver.close_edf_file()
-
-    def notify_observers(self, data_block):
-        for obj in self._observers:
-            obj.update(data_block)
-
-    def restart_wave(self):
-        self._sock.sendall(bytes('restart', encoding='utf-8'))

+ 0 - 0
backend/tests/data/.gitkeep


+ 52 - 0
backend/tests/test_neoloader.py

@@ -0,0 +1,52 @@
+import unittest
+from dataloaders import neo
+import mne
+import numpy as np
+
+
class TestDataloader(unittest.TestCase):
    """Tests for the session loader and event reconstruction in ``dataloaders.neo``."""

    def test_load_sample_data(self):
        data_root = './tests/data'
        session_map = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
        label_ids = {'rest': 0, 'cylinder': 1, 'ball': 2}

        raw = neo.raw_preprocessing(data_root, session_map)
        loaded_events, _ = mne.events_from_annotations(raw, event_id=label_ids)
        # count how many events carry each label (labels come back sorted)
        _, label_counts = np.unique(loaded_events[:, -1], return_counts=True)
        self.assertTrue(np.allclose(label_counts, (300, 150, 150)))

    def test_load_session(self):
        data_root = './tests/data'
        session_map = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4']}
        loaded = neo.load_sessions(data_root, session_map)
        # test if interleaved
        model_order = tuple(model for model, _ in loaded)
        self.assertEqual(len(loaded), 3)
        self.assertTupleEqual(model_order, ('cylinder', 'ball', 'cylinder'))

    def test_event_parser(self):
        fs = 100
        raw_events = np.array([[0, 0, 4], [100, 0, 4], [600, 0, 3], [700, 0, 3], [1000, 0, 4], [1100, 0, 4]])

        # fixed length
        expected = np.array([[0, 400, 0], [600, 400, 2], [1000, 400, 0]])
        parsed = neo.reconstruct_events(raw_events, fs, 'ball', trial_duration=4)
        self.assertTrue(np.allclose(parsed, expected))

        # varying length
        expected = np.array([[0, 600, 0], [600, 400, 2], [1000, 100, 0]])
        parsed = neo.reconstruct_events(raw_events, fs, 'ball', trial_duration=None)
        self.assertTrue(np.allclose(parsed, expected))

        # change indices
        expected = np.array([[0, 400, 2], [600, 400, 0], [1000, 400, 2]])
        parsed = neo.reconstruct_events(raw_events, fs, 'ball', mov_trial_ind=[4], rest_trial_ind=[2, 3], trial_duration=4)
        self.assertTrue(np.allclose(parsed, expected))

        # use original indices
        expected = np.array([[0, 400, 4], [600, 400, 3], [1000, 400, 4]])
        parsed = neo.reconstruct_events(raw_events, fs, None, trial_duration=4, use_original_label=True, mov_trial_ind=[3], rest_trial_ind=[4])
        self.assertTrue(np.allclose(parsed, expected))

        # use original indices, with labels that are already 0 / 1
        expected = np.array([[0, 400, 0], [600, 400, 1], [1000, 400, 0]])
        raw_events = np.array([[0, 0, 0], [100, 0, 0], [600, 0, 1], [700, 0, 1], [1000, 0, 0], [1100, 0, 0]])
        parsed = neo.reconstruct_events(raw_events, fs, None, trial_duration=4, use_original_label=True, mov_trial_ind=[1], rest_trial_ind=[0])
        self.assertTrue(np.allclose(parsed, expected))
+

+ 92 - 0
backend/tests/test_online.py

@@ -0,0 +1,92 @@
+import os
+import shutil
+import random
+import bci_core.online as online
+import training
+from dataloaders import library_ieeg
+from validation import DataGenerator
+import unittest
+import numpy as np
+from glob import glob
+
+
class TestOnline(unittest.TestCase):
    """Runs the online controller on a model trained once for the whole class."""

    @classmethod
    def setUpClass(cls):
        data_root = './tests/data'
        subject_id = 'f77cbe10a8de473992542e9f4e913a66'
        event_id = {'ball': 2, 'rest': 0}

        mat_path = os.path.join(data_root, 'ecog-data/1', 'bp_mot_t_h.mat')
        raw = library_ieeg.raw_preprocessing(mat_path, finger_model='ball')
        picked = [raw.info['ch_names'][i] for i in [5,6,7,12,13,14,20,21]]
        raw = raw.pick_channels(picked)

        model = training.train_model(raw, event_id, model_type='baseline')

        # persist the model so the controller can load it from disk
        training.model_saver(model, data_root, 'baseline', subject_id, event_id)
        cls.model_root = os.path.join(data_root, subject_id)
        cls.model_path = glob(os.path.join(data_root, subject_id, '*.pkl'))[0]

        cls.data_gen = DataGenerator(raw.info['sfreq'], raw.get_data())

    @classmethod
    def tearDownClass(cls) -> None:
        # remove the model written by setUpClass
        shutil.rmtree(cls.model_root)
        return super().tearDownClass()

    def test_step_feedback(self):
        controller = online.Controller(0, self.model_path)
        decisions = []
        for _, batch in self.data_gen.loop():
            decisions.append(controller.step_decision(batch))
        self.assertTrue(np.allclose(np.unique(decisions), [0, 2]))

    def test_virtual_feedback(self):
        controller = online.Controller(1, None)

        n_trial = 1000
        # both decision entry points are expected to agree with the label
        # in roughly 80% of the trials
        for decide in (controller.decision, controller.step_decision):
            hits = 0
            for _ in range(n_trial):
                label = random.randint(0, 1)
                if decide(None, label) == label:
                    hits += 1
            self.assertTrue(abs(hits / n_trial - 0.8) < 0.1)

    def test_real_feedback(self):
        controller = online.Controller(0, self.model_path)
        decisions = []
        # only the first 300 batches are evaluated
        for _, (_, batch) in zip(range(300), self.data_gen.loop()):
            decisions.append(controller.decision(batch))
        self.assertTrue(np.allclose(np.unique(decisions), [-1, 0, 2]))
+
+
class TestHMM(unittest.TestCase):
    """Checks the state sequence produced by ``online.HMMModel``."""

    def _run_sequence(self, model, probs):
        # feed the probabilities one by one and collect the reported states
        return [model.update_state(p) for p in probs]

    def test_state_transfer(self):
        # binary
        binary_probs = [[0.9, 0.1], [0.5, 0.5], [0.09, 0.91], [0.5, 0.5], [0.3, 0.7], [0.7, 0.3], [0.92,0.08]]
        binary_expected = [-1, -1, 1, -1, -1, -1, 0]
        binary_model = online.HMMModel(2, state_trans_prob=0.9, state_change_threshold=0.7)
        binary_states = self._run_sequence(binary_model, binary_probs)
        self.assertTrue(np.allclose(binary_states, binary_expected))

        # triple
        triple_probs = [[0.8, 0.1, 0.1], [0.01, 0.91, 0.09], [0.01, 0.08, 0.91], [0.5, 0.2, 0.3], [0.9, 0.05, 0.02], [0.01, 0.01, 0.98]]
        triple_expected = [-1, 1, -1, -1, 0, 2]
        triple_model = online.HMMModel(3, state_trans_prob=0.9)
        triple_states = self._run_sequence(triple_model, triple_probs)
        self.assertTrue(np.allclose(triple_states, triple_expected))

+ 56 - 0
backend/tests/test_training.py

@@ -0,0 +1,56 @@
+import os
+import training
+import unittest
+import joblib
+from glob import glob
+from dataloaders import neo
+from bci_core.feature_extractors import FeatExtractor
+from bci_core.model import baseline_model, stacking_riemann, ChannelScaler
+import shutil
+from sklearn.utils.validation import check_is_fitted, NotFittedError
+from sklearn.pipeline import Pipeline
+from sklearn.ensemble import StackingClassifier
+
+
class TestTraining(unittest.TestCase):
    """Tests for model training and for the on-disk model naming scheme."""

    @classmethod
    def setUpClass(cls) -> None:
        root_path = './tests/data'
        sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
        cls.event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}

        raw = neo.raw_preprocessing(root_path, sessions, rename_event=True)
        raw.drop_channels(['T3', 'T4', 'A1', 'A2', 'T5', 'T6', 'M1', 'M2', 'Fp1', 'Fp2', 'F7', 'F8', 'O1', 'Oz', 'O2', 'F3', 'F4', 'Fz'])
        cls.raw = raw

    def test_training_baseline(self):
        model = training.train_model(self.raw, self.event_id, model_type='baseline')
        check_is_fitted(model)

    def test_saver(self):
        save_root = './tests/data'
        subject_id = 'f77cbe10a8de473992542e9f4e913a66'
        subject_dir = os.path.join(save_root, subject_id)
        # Registered before saving so the folder is removed even when an
        # assertion below fails; otherwise a stale model would break the
        # "exactly one file" check on the next run.
        self.addCleanup(shutil.rmtree, subject_dir, ignore_errors=True)

        feat_ext = FeatExtractor(1000, lfb_bands=[(15, 30), [30, 45]], hg_bands=[(55, 95), (105, 145)])
        model_riemann = stacking_riemann(12, 12, 1, 1)
        model_baseline = baseline_model(1)
        scaler = ChannelScaler()
        event_id = {'1': 5, '0': 3}
        training.model_saver([feat_ext, scaler, model_riemann, model_baseline], save_root, 'baseline', subject_id, event_id)
        self.assertTrue(os.path.isdir(subject_dir))

        model_file = glob(os.path.join(subject_dir, '*.pkl'))
        self.assertEqual(len(model_file), 1)

        # file name is <model type>_<events sorted by label>_<timestamp>.pkl;
        # basename() keeps this independent of the path separator
        class_name, events, date = os.path.basename(model_file[0]).split('_')
        self.assertEqual(class_name, 'baseline')
        self.assertEqual(events, '0+1')
        # load model
        feat, scaler, model_riem, model_base = joblib.load(model_file[0])
        self.assertIsInstance(feat, FeatExtractor)
        self.assertIsInstance(scaler, ChannelScaler)
        self.assertIsInstance(model_riem, StackingClassifier)
        self.assertIsInstance(model_base, Pipeline)

+ 53 - 0
backend/tests/test_validation.py

@@ -0,0 +1,53 @@
+import unittest
+import os
+import numpy as np
+
+from bci_core import utils as ana_utils
+from training import train_model
+from dataloaders import library_ieeg
+from validation import validation
+
+
class TestValidation(unittest.TestCase):
    """Tests for the event metric and the simulated-online validation run."""

    @classmethod
    def setUpClass(cls):
        root_path = './tests/data'
        cls.event_id = {'ball': 2, 'rest': 0}

        raw = library_ieeg.raw_preprocessing(os.path.join(root_path, 'ecog-data/1', 'bp_mot_t_h.mat'), finger_model='ball')
        raw = raw.pick_channels([raw.info['ch_names'][i] for i in [5,6,7,12,13,14,20,21]])
        cls.raw = raw
        # split into 2 pieces
        t_min, t_max = raw.times[0], raw.times[-1]
        t_mid = raw.times[len(raw.times) // 2]
        raw_train = raw.copy().crop(tmin=t_min, tmax=t_mid, include_tmax=True)
        # NOTE(review): raw_val is prepared here but test_validation runs on
        # cls.raw (the whole recording) - confirm which one is intended.
        cls.raw_val = raw.copy().crop(tmin=t_mid, tmax=t_max)

        # reconstruct single event for validation
        if cls.raw_val.annotations.onset[0] > t_mid:
            # correct time by first timestamp
            cls.raw_val.annotations.onset -= t_mid

        # train with the first half
        cls.model = train_model(raw_train, event_id=cls.event_id, model_type='baseline')

    def test_event_metric(self):
        event_gt = np.array([[0, 0, 0], [5, 0, 1], [7, 0, 0], [9, 0, 2]])
        event_pred = np.array([[1, 0, 0], [4, 0, 1], [6, 0, 1], [7, 0, 0], [10, 0, 1], [11, 0, 2]])
        fs = 1
        precision, recall, f1_score = ana_utils.event_metric(event_gt, event_pred, fs, ignore_event=(0,))
        # metrics are floating-point ratios: compare with a tolerance so the
        # test does not depend on the order of the arithmetic
        self.assertAlmostEqual(f1_score, 2 / 3)
        self.assertAlmostEqual(precision, 1 / 2)
        self.assertAlmostEqual(recall, 1)

    def test_validation(self):
        (precision, recall, f1_score, r), fig_erds, fig_pred = validation(self.raw, 'baseline', self.event_id, model=self.model, state_change_threshold=0.7)
        fig_erds.savefig('./tests/data/erds.pdf')
        fig_pred.savefig('./tests/data/pred.pdf')

        # assertGreater reports the actual value on failure
        self.assertGreater(f1_score, 0.9)
        self.assertGreater(r, 0.5)
+
+
+if __name__ == '__main__':
+    unittest.main()

+ 126 - 0
backend/training.py

@@ -0,0 +1,126 @@
+import logging
+import joblib
+import os
+from datetime import datetime
+from functools import partial
+import yaml
+
+import mne
+import numpy as np
+from scipy import signal
+from pyriemann.estimation import BlockCovariances
+
+import bci_core.feature_extractors as feature_extractors
+import bci_core.utils as bci_utils
+import bci_core.model as bci_model
+from dataloaders import neo
+
+
def train_model(raw, event_id, model_type='baseline'):
    """Train a decoder of the requested type on an annotated recording.

    Args:
        raw: recording whose annotations are converted to events via ``event_id``.
        event_id: dict mapping annotation names to integer labels.
        model_type: ``'baseline'`` or ``'riemann'`` (case-insensitive).

    Returns:
        Whatever the selected training routine returns.

    Raises:
        NotImplementedError: for any other ``model_type``.
    """
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    kind = model_type.lower()
    if kind == 'baseline':
        return _train_baseline_model(raw, events)
    if kind == 'riemann':
        # TODO: load subject config
        return _train_riemann_model(raw, events)
    raise NotImplementedError
+
+
def _train_riemann_model(raw, events, lfb_bands=[(15, 30), [30, 45]], hg_bands=[(55, 95), (105, 145)]):
    """Train the stacked Riemannian classifier on block covariances.

    Steps: filter-bank feature extraction, 1 s epochs starting at each event
    onset, per-channel scaling, block covariance estimation (one block for the
    low-frequency bands, one for the high-gamma bands), hyper-parameter
    search, final fit.

    Args:
        raw: recording to train on.
        events: ``(n, 3)`` event array; column 0 is the onset, last column the label.
        lfb_bands: low-frequency band edges passed to the feature extractor.
        hg_bands: high-gamma band edges passed to the feature extractor.

    Returns:
        ``[feat_extractor, scaler, cov_model, classifier]``, in the order the
        stages are applied to data.
    """
    fs = raw.info['sfreq']
    n_ch = len(raw.ch_names)
    feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
    filtered_data = feat_extractor.transform(raw.get_data())
    # TODO: find proper latency
    X = bci_utils.cut_epochs((0, 1., fs), filtered_data, events[:, 0])
    y = events[:, -1]

    scaler = bci_model.ChannelScaler()
    X = scaler.fit_transform(X)

    # compute covariance
    lfb_dim = len(lfb_bands) * n_ch
    hgs_dim = len(hg_bands) * n_ch
    cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
    X_cov = cov_model.fit_transform(X)

    param = {'C_lfb': np.logspace(-4, 0, 5), 'C_hgs': np.logspace(-3, 1, 5)}

    model_func = partial(bci_model.stacking_riemann, lfb_dim=lfb_dim, hgs_dim=hgs_dim)
    best_auc, best_param = bci_utils.param_search(model_func, X_cov, y, param)

    logging.info(f'Best parameter: {best_param}, best auc {best_auc}')

    # train and dump best model
    model_to_train = model_func(**best_param)
    # Fit on the covariance matrices: that is what the parameter search used
    # and what the returned cov_model stage feeds the classifier.
    model_to_train.fit(X_cov, y)
    return [feat_extractor, scaler, cov_model, model_to_train]
+
+
def _train_baseline_model(raw, events):
    """Train the baseline classifier on decimated filter-bank epochs.

    Args:
        raw: recording to train on.
        events: ``(n, 3)`` event array; column 0 is the onset, last column the label.

    Returns:
        The baseline model fitted with the best ``C`` found by the search.
    """
    fs = raw.info['sfreq']
    band_freqs = np.arange(20, 150, 15)
    fb_signal = feature_extractors.filterbank_extractor(raw.get_data(), fs, band_freqs, reshape_freqs_dim=True)

    # 1 s epochs starting at each event onset
    epochs = bci_utils.cut_epochs((0, 1., fs), fb_signal, events[:, 0])

    # Downsample towards 10 Hz in two equal stages: the rate is divided by
    # int(sqrt(fs / 10)) twice (1000 Hz -> 100 Hz -> 10 Hz).
    stage_factor = np.sqrt(fs / 10).astype(np.int16)
    for _ in range(2):
        epochs = signal.decimate(epochs, stage_factor, axis=-1, zero_phase=True)

    X, y = epochs, events[:, -1]

    search_space = {'C': np.logspace(-5, 4, 10)}
    best_auc, best_param = bci_utils.param_search(bci_model.baseline_model, X, y, search_space)
    logging.info(f'Best parameter: {best_param}, best auc {best_auc}')

    classifier = bci_model.baseline_model(**best_param)
    classifier.fit(X, y)
    return classifier
+
+
def model_saver(model, model_path, model_type, subject_id, event_id):
    """Dump a model to ``<model_path>/<subject_id>/<type>_<classes>_<time>.pkl``.

    Args:
        model: object to serialise with joblib.
        model_path: root folder for saved models; created if missing.
        model_type: model type name, first part of the file name.
        subject_id: name of the per-subject sub-folder.
        event_id: dict mapping event names to integer labels. The names,
            sorted by label and joined with ``+``, form the second part of
            the file name.
    """
    # event list should be sorted by class label
    sorted_events = sorted(event_id.items(), key=lambda item: item[1])
    # Extract the keys in the sorted order and store them in a list
    sorted_events = [item[0] for item in sorted_events]

    # makedirs also creates model_path itself when it does not exist yet
    subject_dir = os.path.join(model_path, subject_id)
    os.makedirs(subject_dir, exist_ok=True)

    now = datetime.now()
    classes = '+'.join(sorted_events)
    date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
    model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
    joblib.dump(model, os.path.join(subject_dir, model_name))
+
+
if __name__ == '__main__':
    # Train and save a model for one subject from its training sessions.
    subj_name = 'ylj'
    model_type = 'baseline'
    # TODO: load subject config

    data_dir = f'./data/{subj_name}/train/'
    model_dir = './static/models/'

    # safe_load parses the text it is given, so the file has to be opened;
    # handing it the path would just yield the path string back.
    with open(os.path.join(data_dir, 'info.yml'), 'r') as info_file:
        info = yaml.safe_load(info_file)
    sessions = info['sessions']
    event_id = {'rest': 0}
    for f in sessions.keys():
        event_id[f] = neo.FINGERMODEL_IDS[f]

    # preprocess raw
    raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False)

    # train model
    model = train_model(raw, event_id=event_id, model_type=model_type)

    # save
    model_saver(model, model_dir, model_type, subj_name, event_id)

+ 138 - 0
backend/validation.py

@@ -0,0 +1,138 @@
+'''
+模型模拟在线测试脚本
+数据两折分割,1折训练模型,1折按照在线模式测试:decison AUC + event f1-score
+'''
+import numpy as np
+import matplotlib.pyplot as plt
+import mne
+import yaml
+import os
+import joblib
+from scipy import stats
+from dataloaders import neo
+import training
+import bci_core.online as online
+import bci_core.utils as bci_utils
+import bci_core.viz as bci_viz
+
+
+class DataGenerator:
+    def __init__(self, fs, X):
+        self.fs = int(fs)
+        self.X = X
+
+    def get_data_batch(self, current_index):
+        # return 1s batch
+        # create mne object
+        data = self.X[:, current_index - self.fs:current_index].copy()
+        # append event channel
+        data = np.concatenate((data, np.zeros((1, data.shape[1]))), axis=0)
+        info = mne.create_info([f'S{i}' for i in range(len(data))], self.fs, ['ecog'] * (len(data) - 1) + ['misc'])
+        raw = mne.io.RawArray(data, info, verbose=False)
+        return {'data': raw}
+
+    def loop(self, step_size=0.1):
+        step = int(step_size * self.fs)
+        for i in range(self.fs, self.X.shape[1] + 1, step):
+            yield i / self.fs, self.get_data_batch(i)
+
+
+def validation(raw_val, model_type, event_id, model, state_change_threshold=0.8):
+    """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
+    Args:
+        raw (mne.io.Raw)
+        model_type (string): type of model to train, baseline or riemann
+        event_id (dict)
+        model: validate existing model, 
+        state_change_threshold (float): default 0.8
+
+    Returns:
+        None
+    """
+    fs = raw_val.info['sfreq']
+
+    # plot ersd map
+    events, _ = mne.events_from_annotations(raw_val, event_id)
+    fig_erds = bci_viz.plot_ersd(raw_val.get_data(), events, fs, (0, 1), event_id, 0)
+
+    events_val, _ = mne.events_from_annotations(raw_val, event_id)
+        
+    events_val = neo.reconstruct_events(events_val, 
+                                        fs, 
+                                        finger_model=None,
+                                        rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'], 
+                                        mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'], 
+                                        use_original_label=True)
+    
+    if model_type == 'baseline':
+        hmm_model = online.BaselineHMM(model, state_change_threshold=state_change_threshold)
+    else:
+        raise NotImplementedError
+    controller = online.Controller(0, None)
+    controller.set_real_feedback_model(hmm_model)
+
+    # validate with the second half
+    val_data = raw_val.get_data()
+    data_gen = DataGenerator(fs, val_data)
+    rets = []
+    for time, data in data_gen.loop():
+        cls = controller.decision(data)
+        rets.append((time, cls))
+    events_pred = _construct_model_event(rets, fs)
+    precision, recall, f_beta_score = bci_utils.event_metric(event_true=events_val, event_pred=events_pred, fs=fs)
+    stim_pred = _event_to_stim_channel(events_pred, len(raw_val.times))
+    stim_true = _event_to_stim_channel(events_val, len(raw_val.times))
+    corr, p = stats.pearsonr(stim_pred, stim_true)
+    fig_pred, ax = plt.subplots(1, 1)
+    ax.plot(raw_val.times, stim_pred, label='pred')
+    ax.plot(raw_val.times, stim_true, label='true')
+    ax.legend()
+
+    return (precision, recall, f_beta_score, corr), fig_erds, fig_pred
+
+
+def _construct_model_event(decision_seq, fs):
+    events = []
+    for i in decision_seq:
+        time, cls = i
+        if cls >= 0:
+            events.append([int(time * fs), 0, cls])
+    return np.array(events)
+
+
+def _event_to_stim_channel(events, time_length):
+    x = np.zeros(time_length)
+    for i in range(0, len(events) - 1):
+        x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
+    return x
+
+
+if __name__ == '__main__':
+    subj_name = 'ylj'
+    model_type = 'baseline'
+    # TODO: load subject config
+
+    data_dir = f'./data/{subj_name}/val/'
+    model_path = f'./static/models/{subj_name}/scis.pkl'
+
+    info = yaml.safe_load(os.path.join(data_dir, 'info.yml'))
+    sessions = info['sessions']
+    event_id = {'rest': 0}
+    for f in sessions.keys():
+        event_id[f] = neo.FINGERMODEL_IDS[f]
+    
+    # preprocess raw
+    raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False)
+
+    # load model
+    model = joblib.load(model_path)
+    model_type, events = bci_utils.parse_model_type(model_path)
+
+    metrics, fig_erds, fig_pred = validation(raw, 
+                                             model_type, 
+                                             event_id, 
+                                             model=model, 
+                                             state_change_threshold=0.8)
+    fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
+    fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   
+    print(metrics)

+ 1 - 0
environment.yml

@@ -24,3 +24,4 @@ dependencies:
       - scipy~=1.11.3
       - scikit-learn~=1.3.2
       - matplotlib~=3.8.1
+      - pyyaml~=6.0.1

+ 1 - 0
requirements.txt

@@ -17,3 +17,4 @@ SQLAlchemy==2.0.23
 starlette==0.27.0
 streamlit==1.28.1
 opencv_python~=4.8.1
+pyyaml~=6.0.1