import numpy as np
import mne
from mne import create_info

from core.mi.pipeline import BaselineModel


class DataGenerator:
    """Replay a recorded multichannel signal as a stream of 1 s MNE windows.

    ``loop`` walks over ``X`` in 0.1 s hops and, for each hop, yields the
    trailing 1 s window wrapped in an ``mne.io.RawArray`` (plus an extra
    all-zero channel appended as the last channel).
    """

    def __init__(self, fs, X, info):
        """Store the signal and sampling parameters.

        Parameters
        ----------
        fs : float
            Sampling rate in Hz. It is truncated to ``int`` and used as the
            window length in samples.
        X : numpy.ndarray
            2-D array indexed as ``X[channel, sample]``.
        info : mapping
            Only ``info['sfreq']`` is read; it becomes the sampling rate of
            every ``RawArray`` produced (presumably an ``mne.Info``).

        Raises
        ------
        ValueError
            If ``fs`` truncates to less than one sample per second, which
            would make every window empty.
        """
        self.fs = int(fs)
        if self.fs < 1:
            raise ValueError(f"fs must be >= 1 Hz, got {fs!r}")
        self.X = X
        self.info = info

    def get_data_batch(self, current_index):
        """Return the 1 s window ending (exclusive) at sample ``current_index``.

        The window covers samples ``[current_index - fs, current_index)`` of
        every channel in ``X``. An all-zero channel is appended as the event
        channel. Channels are named ``S0..S{n}``; all are typed ``'ecog'``
        except the appended one, which is ``'misc'``.

        Returns
        -------
        dict
            ``{'data': mne.io.RawArray}``.

        Raises
        ------
        ValueError
            If the requested window does not lie fully inside ``X``. A
            negative slice start would otherwise wrap around and silently
            return the wrong samples.
        """
        start = current_index - self.fs
        n_samples = self.X.shape[1]
        if start < 0 or current_index > n_samples:
            raise ValueError(
                f"window [{start}, {current_index}) is outside the "
                f"available samples [0, {n_samples})"
            )
        window = self.X[:, start:current_index]
        # np.concatenate allocates a new array, so the result never aliases X.
        data = np.concatenate((window, np.zeros((1, window.shape[1]))), axis=0)
        n_channels = len(data)
        info = create_info(
            [f'S{i}' for i in range(n_channels)],
            self.info['sfreq'],
            ['ecog'] * (n_channels - 1) + ['misc'],
        )
        raw = mne.io.RawArray(data, info, verbose=False)
        return {'data': raw}

    def loop(self):
        """Yield ``(t, batch)`` for every 1 s window, hopping by 0.1 s.

        ``t`` is the window's end time in seconds (end sample / ``fs``), and
        ``batch`` is the dict returned by ``get_data_batch``. Nothing is
        yielded if ``X`` holds fewer than ``fs`` samples.
        """
        # 0.1 s hop, clamped to >= 1 sample so range() never receives a zero
        # step (which raises ValueError) when fs < 10 Hz.
        step = max(1, int(0.1 * self.fs))
        for end in range(self.fs, self.X.shape[1] + 1, step):
            yield end / self.fs, self.get_data_batch(end)


def test_pipeline():
    """Smoke-run ``BaselineModel`` over a recorded file, printing each decision.

    NOTE(review): depends on the data/model files at the relative paths below
    (resolved against the current working directory) and asserts nothing.
    """
    data = mne.io.read_raw("core/mi/raw_eeg.fif")
    X = data.get_data()
    info = data.info.copy()
    gen = DataGenerator(info["sfreq"], X, info)
    pipeline = BaselineModel("core/mi/bp-baseline.pkl")
    for t, batch_data in gen.loop():
        print(pipeline.smoothed_decision(batch_data))