|
@@ -19,7 +19,8 @@ from settings.config import settings
|
|
|
config_info = settings.CONFIG_INFO
|
|
|
|
|
|
class HMMClassifier(hmm.BaseHMM):
|
|
|
- # TODO: how to bypass sklearn.check_array, currently I modified the src of hmmlearn (remove all the check_array)
|
|
|
+ # TODO: 如何绕过hmmlearn里使用的sklearn.utils.validation.check_array,目前我直接修改了hmmlearn的源码(删除所有的check_array)
|
|
|
+ # TODO: 可行的方法是修改模型组织,将特征提取步骤与最终分类器分开,模型只保留最终分类器,这样仅需接收二维特征。
|
|
|
def __init__(self, emission_model, **kwargs):
|
|
|
n_components = len(emission_model.classes_)
|
|
|
super(HMMClassifier, self).__init__(n_components=n_components, params='t', init_params='st', **kwargs)
|
|
@@ -136,7 +137,7 @@ with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
|
|
|
info = yaml.safe_load(f)
|
|
|
sessions = info['hmm_sessions']
|
|
|
|
|
|
-raw = neo.raw_loader(data_dir, sessions, True)
|
|
|
+raw, event_id = neo.raw_loader(data_dir, sessions, True)
|
|
|
|
|
|
# cut into buffer len epochs
|
|
|
if model_type == 'baseline':
|