|
@@ -125,8 +125,6 @@ def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_tr
|
|
|
|
|
|
|
|
|
def simulation(raw_val, event_id, model,
|
|
|
- state_trans_prob=0.8,
|
|
|
- state_change_threshold=0.8,
|
|
|
epoch_length=1.,
|
|
|
step_length=0.1,
|
|
|
event_trial_length=5.):
|
|
@@ -135,7 +133,6 @@ def simulation(raw_val, event_id, model,
|
|
|
raw (mne.io.Raw)
|
|
|
event_id (dict)
|
|
|
model: validate existing model,
|
|
|
- state_change_threshold (float): default 0.8
|
|
|
epoch_length (float): batch data length, default 1 (s)
|
|
|
step_length (float): data step length, default 0.1 (s)
|
|
|
event_trial_length (float):
|
|
@@ -153,15 +150,11 @@ def simulation(raw_val, event_id, model,
|
|
|
rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
|
|
|
mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
|
|
|
use_original_label=True)
|
|
|
-
|
|
|
- model_hmm = online.model_loader(model,
|
|
|
- state_trans_prob=state_trans_prob,
|
|
|
- state_change_threshold=state_change_threshold)
|
|
|
|
|
|
# run with and without hmm
|
|
|
fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val,
|
|
|
events_val,
|
|
|
- model_hmm,
|
|
|
+ model,
|
|
|
epoch_length,
|
|
|
step_length,
|
|
|
event_trial_length=event_trial_length)
|
|
@@ -194,7 +187,10 @@ if __name__ == '__main__':
|
|
|
|
|
|
data_dir = f'./data/{subj_name}/'
|
|
|
|
|
|
- model_path = f'./static/models/{subj_name}/{args.model_filename}'
|
|
|
+ model_filename = args.model_filename.split('.')[0]
|
|
|
+ model_path = f'./static/models/{subj_name}/{model_filename}.pkl'
|
|
|
+ transmat_path = f'./static/models/{subj_name}/{model_filename}_transmat.txt'
|
|
|
+
|
|
|
with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
|
|
|
info = yaml.safe_load(f)
|
|
|
sessions = info['sessions']
|
|
@@ -210,13 +206,19 @@ if __name__ == '__main__':
|
|
|
mov_trial_ind=[2],
|
|
|
rest_trial_ind=[1],
|
|
|
upsampled_epoch_length=None)
|
|
|
+
|
|
|
+ # load model
|
|
|
+ input_kwargs = {
|
|
|
+ 'transmat': transmat_path if os.path.isfile(transmat_path) else None,
|
|
|
+ 'state_trans_prob': args.state_trans_prob,
|
|
|
+ 'state_change_threshold': args.state_change_threshold
|
|
|
+ }
|
|
|
+ model_hmm = online.model_loader(model_path, **input_kwargs)
|
|
|
|
|
|
# do validations
|
|
|
metric_hmm, metric_naive, fig_pred = simulation(raw,
|
|
|
event_id,
|
|
|
- model=model_path,
|
|
|
- state_trans_prob=args.state_trans_prob,
|
|
|
- state_change_threshold=args.state_change_threshold,
|
|
|
+ model=model_hmm,
|
|
|
epoch_length=config_info['buffer_length'],
|
|
|
step_length=config_info['buffer_length'],
|
|
|
event_trial_length=trial_time)
|