import joblib
import numpy as np
from scipy import signal

from core.mi.feature_extractors import filterbank_extractor


class BaselineModel:
    """Binary classifier wrapper that smooths per-step probabilities.

    Each call to :meth:`smoothed_decision` runs an incoming data chunk
    through ``filterbank_extractor``, downsamples the result, asks the loaded
    model for the probability in column 1 of ``predict_proba``, and averages
    that probability over the most recent ``buffer_steps`` calls. The average
    is compared against ``threshold`` to produce a 0/1 decision.
    """

    def __init__(self, model_path, buffer_steps=5, threshold=0.9):
        """Load the model and initialise the smoothing buffer.

        Args:
            model_path: Path to a model serialized with ``joblib``. The loaded
                object must provide ``predict_proba`` (presumably a
                scikit-learn style estimator -- confirm).
            buffer_steps: Number of most recent step probabilities averaged
                by :meth:`smoothed_decision`. Must be at least 1.
            threshold: Decision threshold applied to the averaged
                probability. A decision of 1 requires the mean to be strictly
                greater than this value. Defaults to 0.9, the value that was
                previously hard-coded.

        Raises:
            ValueError: If ``buffer_steps`` is less than 1. With such a value
                the buffer is always empty after trimming, and its mean is
                undefined (NaN).
        """
        if buffer_steps < 1:
            raise ValueError(f"buffer_steps must be >= 1, got {buffer_steps!r}")
        # NOTE: joblib.load unpickles the file, which can execute arbitrary
        # code -- only load model files from a trusted source.
        self.model = joblib.load(model_path)
        # Frequencies handed to filterbank_extractor: 20, 35, ..., 140.
        self._freqs = np.arange(20, 150, 15)
        self.buffer_steps = buffer_steps
        self.threshold = threshold
        self.buffer = []

    def reset_buffer(self):
        """Discard all stored step probabilities."""
        self.buffer = []

    def step_probability(self, fs, data):
        """Return the model's probability for column 1 on one data chunk.

        Args:
            fs: Sampling frequency of ``data``.
            data: Signal array as returned by :meth:`_parse_data` (presumably
                channels x samples -- confirm against the caller).

        Returns:
            Entry ``[0, 1]`` of ``predict_proba`` for the single sample built
            from ``data``.
        """
        # TODO: confirm whether the raw data needs rescaling before filtering
        # (previously considered: data *= 0.0298 * 1e-6).
        filter_bank_data = filterbank_extractor(
            data, fs, self._freqs, reshape_freqs_dim=True
        )
        # Downsample along the last axis by a total factor of 100, in two
        # stages of 10. scipy recommends calling decimate repeatedly for
        # factors above 13. zero_phase=True avoids introducing a phase shift.
        for _ in range(2):
            filter_bank_data = signal.decimate(
                filter_bank_data, 10, axis=-1, zero_phase=True
            )
        # Prepend a batch axis of size 1 so the model scores a single sample.
        proba = self.model.predict_proba(filter_bank_data[None])
        return proba[0, 1]

    def _parse_data(self, data):
        """Unpack a data payload into ``(sampling frequency, signal array)``.

        ``data`` is expected to be a mapping whose ``'data'`` entry provides
        ``info['sfreq']`` and ``get_data()``. This matches an MNE ``Raw``
        object -- confirm against the caller.

        Returns:
            Tuple ``(fs, array)``, where ``array`` is ``get_data()`` with its
            last entry along the first axis removed.
        """
        raw = data['data']
        fs = raw.info['sfreq']
        # Drop the last channel (noted in the original code as the event channel).
        signal_array = raw.get_data()[:-1]
        return fs, signal_array

    def smoothed_decision(self, data):
        """Interface for class decision.

        Compute the step probability for ``data`` and append it to the rolling
        buffer, keeping only the latest ``buffer_steps`` values. Return 1 if
        the buffer mean is greater than ``threshold``, otherwise 0.
        """
        fs, signal_array = self._parse_data(data)
        self.buffer.append(self.step_probability(fs, signal_array))
        # Trim with a loop rather than a single pop, so the window stays
        # correct even if buffer_steps was lowered after construction.
        while len(self.buffer) > self.buffer_steps:
            self.buffer.pop(0)
        mean_p = np.mean(self.buffer)
        return int(mean_p > self.threshold)