|
@@ -177,10 +177,8 @@ class HMMModel:
|
|
|
class BaselineHMM(HMMModel):
|
|
|
def __init__(self, model, **kwargs):
|
|
|
if isinstance(model, str):
|
|
|
- self.model = joblib.load(model)
|
|
|
- else:
|
|
|
- self.model = model
|
|
|
- self.freqs = np.arange(20, 150, 15)
|
|
|
+ model = joblib.load(model)
|
|
|
+ self.feat_extractor, self.model = model
|
|
|
|
|
|
super(BaselineHMM, self).__init__(n_classes=len(self.model.classes_), **kwargs)
|
|
|
|
|
@@ -189,7 +187,7 @@ class BaselineHMM(HMMModel):
|
|
|
"""
|
|
|
data = super(BaselineHMM, self).step_probability(fs, data)
|
|
|
# filter data
|
|
|
- filter_bank_data = filterbank_extractor(data, fs, self.freqs, reshape_freqs_dim=True)
|
|
|
+ filter_bank_data = self.feat_extractor.transform(data)
|
|
|
# downsampling
|
|
|
decimate_rate = np.sqrt(fs / 10).astype(np.int16)
|
|
|
filter_bank_data = signal.decimate(filter_bank_data, decimate_rate, axis=-1, zero_phase=True)
|
|
@@ -202,9 +200,10 @@ class BaselineHMM(HMMModel):
|
|
|
class RiemannHMM(HMMModel):
|
|
|
def __init__(self, model, **kwargs):
|
|
|
if isinstance(model, str):
|
|
|
- self.feat_extractor, self.scaler, self.cov, self.model = joblib.load(model)
|
|
|
- else:
|
|
|
- self.feat_extractor, self.scaler, self.cov, self.model = model
|
|
|
+ model = joblib.load(model)
|
|
|
+
|
|
|
+ self.feat_extractor, self.scaler, self.cov, self.model = model
|
|
|
+
|
|
|
super(RiemannHMM, self).__init__(n_classes=len(self.model.classes_), **kwargs)
|
|
|
|
|
|
def step_probability(self, fs, data):
|