12345678910111213141516171819202122232425262728293031323334353637383940 |
- import numpy as np
- import mne
- from mne import create_info
- from core.mi.pipeline import BaselineModel
class DataGenerator:
    """Replay a recorded 2-D array as a stream of fixed-length windows.

    ``X`` is sliced along axis 1, and every window spans ``fs`` columns,
    so axis 0 is treated as channels and axis 1 as samples. Each window
    is wrapped in an ``mne.io.RawArray`` with one extra all-zero row
    appended and typed as ``'misc'``.

    Args:
        fs: Sampling rate; truncated to ``int`` and used as the window
            length in samples.
        X: 2-D array indexed as ``X[:, start:stop]``.
        info: Mapping that provides ``'sfreq'`` for the per-window
            ``mne`` info object.
    """

    def __init__(self, fs, X, info):
        self.fs = int(fs)
        self.X = X
        self.info = info

    def get_data_batch(self, current_index):
        """Return the ``fs``-sample window that ends at ``current_index``.

        Args:
            current_index: Exclusive end column of the window in ``X``.

        Returns:
            ``{'data': raw}`` where ``raw`` is an ``mne.io.RawArray``
            holding the window rows (typed ``'ecog'``) followed by one
            all-zero ``'misc'`` row.

        Raises:
            ValueError: If ``current_index`` is smaller than ``fs``; the
                slice start would be negative and NumPy would wrap it
                around to the end of ``X`` instead of failing.
        """
        if current_index < self.fs:
            raise ValueError(
                f"current_index ({current_index}) must be at least the "
                f"window length fs ({self.fs})"
            )

        # Copy so that downstream processing cannot modify self.X.
        window = self.X[:, current_index - self.fs:current_index].copy()

        # Append a single all-zero row; it is declared as 'misc' below.
        zero_row = np.zeros((1, window.shape[1]))
        data = np.concatenate((window, zero_row), axis=0)

        n_rows = len(data)
        ch_names = [f'S{i}' for i in range(n_rows)]
        ch_types = ['ecog'] * (n_rows - 1) + ['misc']
        batch_info = create_info(ch_names, self.info['sfreq'], ch_types)
        raw = mne.io.RawArray(data, batch_info, verbose=False)
        return {'data': raw}

    def loop(self):
        """Yield ``(time, batch)`` pairs for successive windows over ``X``.

        The first window ends at column ``fs``; each following window
        ends ``int(0.1 * fs)`` columns later. ``time`` is the window's
        end column divided by ``fs`` and ``batch`` is the result of
        ``get_data_batch`` for that end column.
        """
        # int(0.1 * fs) is 0 for fs < 10, and range() rejects a zero
        # step, so always advance by at least one column.
        step = max(1, int(0.1 * self.fs))
        for end in range(self.fs, self.X.shape[1] + 1, step):
            yield end / self.fs, self.get_data_batch(end)
def test_pipeline():
    """Smoke-run the baseline model over a recorded file, window by window.

    Reads the recording, streams it through ``DataGenerator`` and prints
    the model's ``smoothed_decision`` for every window. Nothing is
    asserted; the run only shows that the pipeline executes end to end.
    """
    recording = mne.io.read_raw("core/mi/raw_eeg.fif")
    samples = recording.get_data()
    # Work on a copy of the measurement info, as the sampling rate is
    # read from it and it is handed to the generator.
    meta = recording.info.copy()
    generator = DataGenerator(meta["sfreq"], samples, meta)
    model = BaselineModel("core/mi/bp-baseline.pkl")
    for _timestamp, window in generator.loop():
        decision = model.smoothed_decision(window)
        print(decision)
-
|