from collections import deque

import joblib
import numpy as np
from scipy import signal

from core.mi.feature_extractors import filterbank_extractor
class BaselineModel:
    """Binary decision model that smooths per-step class probabilities.

    Loads a pre-trained, joblib-serialized classifier exposing
    ``predict_proba`` and turns a stream of raw data chunks into 0/1
    decisions by averaging the last ``buffer_steps`` class-1 probabilities
    and thresholding the mean.
    """

    def __init__(self, model_path, buffer_steps=5, threshold=0.9):
        """
        Parameters
        ----------
        model_path : str or Path
            Path to a joblib-serialized model exposing ``predict_proba``.
        buffer_steps : int
            Number of most recent probabilities kept for smoothing.
        threshold : float
            Decision threshold on the smoothed class-1 probability.
            Default 0.9 matches the previously hard-coded value.
        """
        self.model = joblib.load(model_path)
        # Filter-bank center frequencies in Hz: 20, 35, ..., 145.
        self._freqs = np.arange(20, 150, 15)
        self.buffer_steps = buffer_steps
        self.threshold = threshold
        # deque(maxlen=...) drops the oldest entry automatically in O(1),
        # replacing the O(n) list.pop(0) of the previous implementation.
        self.buffer = deque(maxlen=buffer_steps)

    def reset_buffer(self):
        """Empty the smoothing buffer (e.g. at the start of a new session)."""
        self.buffer.clear()

    def step_probability(self, fs, data):
        """Return the model's class-1 probability for one chunk of data.

        Parameters
        ----------
        fs : float
            Sampling frequency of ``data`` in Hz.
        data : np.ndarray
            Channels-by-samples array. NOTE(review): it is unclear whether
            amplitude scaling is required here -- see the commented factor.

        Returns
        -------
        float
            Probability of the positive class (index 1 of ``predict_proba``).
        """
        # TODO: make sure if scaling is needed
        # data *= 0.0298 * 1e-6
        # Band features per filter-bank frequency (project helper).
        filter_bank_data = filterbank_extractor(
            data, fs, self._freqs, reshape_freqs_dim=True
        )
        # Downsample 100x in two 10x stages along the time axis; scipy
        # recommends keeping each decimation factor <= 13, so two stages
        # keep the anti-aliasing filters well-conditioned.
        filter_bank_data = signal.decimate(
            filter_bank_data, 10, axis=-1, zero_phase=True
        )
        filter_bank_data = signal.decimate(
            filter_bank_data, 10, axis=-1, zero_phase=True
        )
        # The model expects a leading batch dimension; squeeze it back off.
        p = self.model.predict_proba(filter_bank_data[None]).squeeze()
        return p[1]

    def _parse_data(self, payload):
        """Extract ``(fs, channel_data)`` from an incoming payload.

        Assumes ``payload['data']`` is an MNE-Raw-like object exposing
        ``.info['sfreq']`` and ``.get_data()`` -- TODO confirm against caller.
        """
        raw = payload['data']
        fs = raw.info['sfreq']
        channels = raw.get_data()
        # Drop the last channel (event/trigger channel, per original comment).
        channels = channels[:-1]
        return fs, channels

    def _smooth(self, p):
        """Append ``p`` to the buffer and return the thresholded mean as int."""
        self.buffer.append(p)
        # Fallback trim for a plain-list buffer; a maxlen deque never grows
        # past buffer_steps, so this branch is a no-op for it.
        if len(self.buffer) > self.buffer_steps:
            self.buffer.pop(0)
        avg_p = np.mean(self.buffer)
        return int(avg_p > self.threshold)

    def smoothed_decision(self, data):
        """
        Interface for class decision.

        Returns 1 when the mean of the last ``buffer_steps`` class-1
        probabilities exceeds ``self.threshold``, else 0.
        """
        fs, chunk = self._parse_data(data)
        p = self.step_probability(fs, chunk)
        return self._smooth(p)