|
@@ -3,6 +3,7 @@ import numpy as np
|
|
|
import random
|
|
|
import logging
|
|
|
from scipy import signal
|
|
|
+import mne
|
|
|
from .feature_extractors import filterbank_extractor
|
|
|
from .utils import parse_model_type
|
|
|
|
|
@@ -132,7 +133,10 @@ class HMMModel:
|
|
|
self._probability[current_state] = 1
|
|
|
|
|
|
def step_probability(self, fs, data):
|
|
|
- raise NotImplementedError
|
|
|
+ # do preprocessing here
|
|
|
+ # common average
|
|
|
+ data -= data.mean(axis=0)
|
|
|
+ return data
|
|
|
|
|
|
def parse_data(self, data):
|
|
|
fs, event, data_array = data
|
|
@@ -181,6 +185,7 @@ class BaselineHMM(HMMModel):
|
|
|
def step_probability(self, fs, data):
|
|
|
"""Step
|
|
|
"""
|
|
|
+ data = super(BaselineHMM, self).step_probability(fs, data)
|
|
|
# filter data
|
|
|
filter_bank_data = filterbank_extractor(data, fs, self.freqs, reshape_freqs_dim=True)
|
|
|
# downsampling
|
|
@@ -203,6 +208,7 @@ class RiemannHMM(HMMModel):
|
|
|
def step_probability(self, fs, data):
|
|
|
"""Step
|
|
|
"""
|
|
|
+ data = super(RiemannHMM, self).step_probability(fs, data)
|
|
|
data = self.feat_extractor.transform(data)
|
|
|
# scale data
|
|
|
data = self.scaler.transform(data)
|