# pipeline.py
import joblib
import numpy as np
from scipy import signal

from core.mi.feature_extractors import filterbank_extractor
  5. class BaselineModel:
  6. def __init__(self, model_path, buffer_steps=5):
  7. self.model = joblib.load(model_path)
  8. self._freqs = np.arange(20, 150, 15)
  9. self.buffer_steps = buffer_steps
  10. self.buffer = []
  11. def reset_buffer(self):
  12. self.buffer = []
  13. def step_probability(self, fs, data):
  14. # TODO: make sure if scaling is needed
  15. # data *= 0.0298 * 1e-6
  16. # filter data
  17. filter_bank_data = filterbank_extractor(data, fs, self._freqs, reshape_freqs_dim=True)
  18. # downsampling
  19. filter_bank_data = signal.decimate(filter_bank_data, 10, axis=-1, zero_phase=True)
  20. filter_bank_data = signal.decimate(filter_bank_data, 10, axis=-1, zero_phase=True)
  21. # predict proba
  22. p = self.model.predict_proba(filter_bank_data[None]).squeeze()
  23. return p[1]
  24. def _parse_data(self, data):
  25. data = data['data']
  26. fs = data.info['sfreq']
  27. data = data.get_data()
  28. # drop last event channel
  29. data = data[:-1]
  30. return fs, data
  31. def smoothed_decision(self, data):
  32. """
  33. Interface for class decision
  34. """
  35. fs, data = self._parse_data(data)
  36. p = self.step_probability(fs, data)
  37. self.buffer.append(p)
  38. if len(self.buffer) > self.buffer_steps:
  39. self.buffer.pop(0)
  40. aveg_p = np.mean(self.buffer)
  41. return int(aveg_p > 0.9)