Parcourir la source

add main function

dk il y a 1 an
Parent
commit
3d9e9645ff
1 fichiers modifiés avec 33 ajouts et 31 suppressions
  1. 33 31
      backend/train_hmm.py

+ 33 - 31
backend/train_hmm.py

@@ -120,43 +120,45 @@ def parse_args():
     )
     return parser.parse_args()
 
-args = parse_args()
-# load model and fit hmm
-subj_name = args.subj
-model_filename = args.model_filename
 
-data_dir = f'./data/{subj_name}/'
-    
-model_path = f'./static/models/{subj_name}/{model_filename}'
+if __name__ == '__main__':
+    args = parse_args()
+    # load model and fit hmm
+    subj_name = args.subj
+    model_filename = args.model_filename
+
+    data_dir = f'./data/{subj_name}/'
+        
+    model_path = f'./static/models/{subj_name}/{model_filename}'
 
-# load model
-model_type, _ = bci_utils.parse_model_type(model_filename)
-model = joblib.load(model_path)
+    # load model
+    model_type, _ = bci_utils.parse_model_type(model_filename)
+    model = joblib.load(model_path)
 
-with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
-    info = yaml.safe_load(f)
-sessions = info['hmm_sessions']
+    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
+        info = yaml.safe_load(f)
+    sessions = info['hmm_sessions']
 
-raw, event_id = neo.raw_loader(data_dir, sessions, True)
+    raw, event_id = neo.raw_loader(data_dir, sessions, True)
 
-# cut into buffer len epochs
-if model_type == 'baseline':
-    feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
-elif model_type == 'riemann':
-    feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
-else:
-    raise ValueError
+    # cut into buffer len epochs
+    if model_type == 'baseline':
+        feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
+    elif model_type == 'riemann':
+        feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
+    else:
+        raise ValueError
 
-# initiate hmm model
-hmm_model = HMMClassifier(model[-1], n_iter=100)
-hmm_model.fit(feature)
+    # initiate hmm model
+    hmm_model = HMMClassifier(model[-1], n_iter=100)
+    hmm_model.fit(feature)
 
-# decode
-log_probs, state_seqs = hmm_model.decode(feature)
-plt.figure()
-plt.plot(state_seqs)
+    # decode
+    log_probs, state_seqs = hmm_model.decode(feature)
+    plt.figure()
+    plt.plot(state_seqs)
 
-# save transmat
-np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
+    # save transmat
+    np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
 
-plt.show()
+    plt.show()