test_riemannian.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import numpy as np
  2. import mne
  3. from mne import create_info
  4. from core.mi.pipeline import BaselineModel
  5. class DataGenerator:
  6. def __init__(self, fs, X, info):
  7. self.fs = int(fs)
  8. self.X = X
  9. self.info = info
  10. def get_data_batch(self, current_index):
  11. # return 1s batch
  12. # create mne object
  13. data = self.X[:, current_index - self.fs:current_index].copy()
  14. # append event channel
  15. data = np.concatenate((data, np.zeros((1, data.shape[1]))), axis=0)
  16. info = create_info([f'S{i}' for i in range(len(data))], self.info['sfreq'], ['ecog'] * (len(data) - 1) + ['misc'])
  17. raw = mne.io.RawArray(data, info, verbose=False)
  18. return {'data': raw}
  19. def loop(self):
  20. # 0.1s step
  21. step = int(0.1 * self.fs)
  22. for i in range(self.fs, self.X.shape[1] + 1, step):
  23. yield i / self.fs, self.get_data_batch(i)
  24. def test_pipeline():
  25. data = mne.io.read_raw("core/mi/raw_eeg.fif")
  26. X = data.get_data()
  27. info = data.info.copy()
  28. gen = DataGenerator(info["sfreq"], X, info)
  29. pipeline = BaselineModel("core/mi/bp-baseline.pkl")
  30. for t, batch_data in gen.loop():
  31. print(pipeline.smoothed_decision(batch_data))