dk 1 vuosi sitten
vanhempi
commit
7b7e9c25bd
2 muutettua tiedostoa jossa 3 lisäystä ja 2 poistoa
  1. 1 1
      backend/dataloaders/neo.py
  2. 2 1
      backend/train_hmm.py

+ 1 - 1
backend/dataloaders/neo.py

@@ -23,7 +23,7 @@ def raw_loader(data_root, session_paths:dict,
         session_paths: dict of lists
         reref_method (str): rereference method: monopolar, average, or bipolar
         upsampled_epoch_length (None or float): None: do not do upsampling
-        ori_epoch_length (int or 'varied'): original epoch length in second
+        ori_epoch_length (int, dict, or 'varied'): original epoch length in second
     """
     raws_loaded = load_sessions(data_root, session_paths, reref_method)
     # process event

+ 2 - 1
backend/train_hmm.py

@@ -139,7 +139,7 @@ if __name__ == '__main__':
         info = yaml.safe_load(f)
     sessions = info['hmm_sessions']
 
-    raw, event_id = neo.raw_loader(data_dir, sessions, True)
+    raw, event_id = neo.raw_loader(data_dir, sessions, config_info['reref'])
 
     # cut into buffer len epochs
     if model_type == 'baseline':
@@ -150,6 +150,7 @@ if __name__ == '__main__':
         raise ValueError
 
     # initiate hmm model
+    # TODO: building transmat init
     hmm_model = HMMClassifier(model[-1], n_iter=100)
     hmm_model.fit(feature)