|
@@ -120,43 +120,45 @@ def parse_args():
|
|
)
|
|
)
|
|
return parser.parse_args()
|
|
return parser.parse_args()
|
|
|
|
|
|
-args = parse_args()
|
|
|
|
-# load model and fit hmm
|
|
|
|
-subj_name = args.subj
|
|
|
|
-model_filename = args.model_filename
|
|
|
|
|
|
|
|
-data_dir = f'./data/{subj_name}/'
|
|
|
|
-
|
|
|
|
-model_path = f'./static/models/{subj_name}/{model_filename}'
|
|
|
|
|
|
+if __name__ == '__main__':
|
|
|
|
+ args = parse_args()
|
|
|
|
+ # load model and fit hmm
|
|
|
|
+ subj_name = args.subj
|
|
|
|
+ model_filename = args.model_filename
|
|
|
|
+
|
|
|
|
+ data_dir = f'./data/{subj_name}/'
|
|
|
|
+
|
|
|
|
+ model_path = f'./static/models/{subj_name}/{model_filename}'
|
|
|
|
|
|
-# load model
|
|
|
|
-model_type, _ = bci_utils.parse_model_type(model_filename)
|
|
|
|
-model = joblib.load(model_path)
|
|
|
|
|
|
+ # load model
|
|
|
|
+ model_type, _ = bci_utils.parse_model_type(model_filename)
|
|
|
|
+ model = joblib.load(model_path)
|
|
|
|
|
|
-with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
|
|
|
|
- info = yaml.safe_load(f)
|
|
|
|
-sessions = info['hmm_sessions']
|
|
|
|
|
|
+ with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
|
|
|
|
+ info = yaml.safe_load(f)
|
|
|
|
+ sessions = info['hmm_sessions']
|
|
|
|
|
|
-raw, event_id = neo.raw_loader(data_dir, sessions, True)
|
|
|
|
|
|
+ raw, event_id = neo.raw_loader(data_dir, sessions, True)
|
|
|
|
|
|
-# cut into buffer len epochs
|
|
|
|
-if model_type == 'baseline':
|
|
|
|
- feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
|
|
|
|
-elif model_type == 'riemann':
|
|
|
|
- feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
|
|
|
|
-else:
|
|
|
|
- raise ValueError
|
|
|
|
|
|
+ # cut into buffer len epochs
|
|
|
|
+ if model_type == 'baseline':
|
|
|
|
+ feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
|
|
|
|
+ elif model_type == 'riemann':
|
|
|
|
+ feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
|
|
|
|
+ else:
|
|
|
|
+ raise ValueError
|
|
|
|
|
|
-# initiate hmm model
|
|
|
|
-hmm_model = HMMClassifier(model[-1], n_iter=100)
|
|
|
|
-hmm_model.fit(feature)
|
|
|
|
|
|
+ # initiate hmm model
|
|
|
|
+ hmm_model = HMMClassifier(model[-1], n_iter=100)
|
|
|
|
+ hmm_model.fit(feature)
|
|
|
|
|
|
-# decode
|
|
|
|
-log_probs, state_seqs = hmm_model.decode(feature)
|
|
|
|
-plt.figure()
|
|
|
|
-plt.plot(state_seqs)
|
|
|
|
|
|
+ # decode
|
|
|
|
+ log_probs, state_seqs = hmm_model.decode(feature)
|
|
|
|
+ plt.figure()
|
|
|
|
+ plt.plot(state_seqs)
|
|
|
|
|
|
-# save transmat
|
|
|
|
-np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
|
|
|
|
|
|
+ # save transmat
|
|
|
|
+ np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
|
|
|
|
|
|
-plt.show()
|
|
|
|
|
|
+ plt.show()
|