Quellcode durchsuchen

修正:遗漏的新api修正

dk vor 1 Jahr
Ursprung
Commit
bb0dbc488e
1 geänderte Dateien mit 3 neuen und 2 gelöschten Zeilen
  1. 3 2
      backend/train_hmm.py

+ 3 - 2
backend/train_hmm.py

@@ -19,7 +19,8 @@ from settings.config import settings
 config_info = settings.CONFIG_INFO
 
 class HMMClassifier(hmm.BaseHMM):
-    # TODO: how to bypass sklearn.check_array, currently I modified the src of hmmlearn (remove all the check_array)
+    # TODO: 如何绕过hmmlearn里使用的sklearn.utils.validation.check_array,目前我直接修改了hmmlearn的源码(删除所有的check_array)
+    # TODO: 可行的方法是修改模型组织,将特征提取步骤与最终分类器分开,模型只保留最终分类器,这样仅需接收二维特征。
     def __init__(self, emission_model, **kwargs):
         n_components = len(emission_model.classes_)
         super(HMMClassifier, self).__init__(n_components=n_components, params='t', init_params='st', **kwargs)
@@ -136,7 +137,7 @@ with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
     info = yaml.safe_load(f)
 sessions = info['hmm_sessions']
 
-raw = neo.raw_loader(data_dir, sessions, True)
+raw, event_id = neo.raw_loader(data_dir, sessions, True)
 
 # cut into buffer len epochs
 if model_type == 'baseline':