|
@@ -42,6 +42,14 @@ def parse_args():
|
|
type=float
|
|
type=float
|
|
)
|
|
)
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
|
+ '--state-trans-prob',
|
|
|
|
+ '-stp',
|
|
|
|
+ dest='state_trans_prob',
|
|
|
|
+ help='Transition probability for HMM state change',
|
|
|
|
+ default=0.8,
|
|
|
|
+ type=float
|
|
|
|
+ )
|
|
|
|
+ parser.add_argument(
|
|
'--model-filename',
|
|
'--model-filename',
|
|
dest='model_filename',
|
|
dest='model_filename',
|
|
help='Model filename',
|
|
help='Model filename',
|
|
@@ -70,16 +78,16 @@ class DataGenerator:
|
|
yield i / self.fs, self.get_data_batch(i)
|
|
yield i / self.fs, self.get_data_batch(i)
|
|
|
|
|
|
|
|
|
|
-def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
|
|
|
|
|
|
+def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_trial_length):
|
|
val_data = raw.get_data()
|
|
val_data = raw.get_data()
|
|
fs = raw.info['sfreq']
|
|
fs = raw.info['sfreq']
|
|
|
|
|
|
- data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
|
|
|
|
|
|
+ data_gen = DataGenerator(fs, val_data, epoch_step=epoch_length)
|
|
|
|
|
|
decision_with_hmm = []
|
|
decision_with_hmm = []
|
|
decision_without_hmm = []
|
|
decision_without_hmm = []
|
|
probs = []
|
|
probs = []
|
|
- for time, data in data_gen.loop():
|
|
|
|
|
|
+ for time, data in data_gen.loop(step_length):
|
|
step_p, cls = model_hmm.viterbi(data, return_step_p=True)
|
|
step_p, cls = model_hmm.viterbi(data, return_step_p=True)
|
|
if cls >=0:
|
|
if cls >=0:
|
|
cls = model_hmm.model.classes_[cls]
|
|
cls = model_hmm.model.classes_[cls]
|
|
@@ -117,16 +125,16 @@ def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
|
|
|
|
|
|
|
|
|
|
def simulation(raw_val, event_id, model,
|
|
def simulation(raw_val, event_id, model,
|
|
- state_change_threshold=0.8,
|
|
|
|
- step_length=1.,
|
|
|
|
|
|
+ epoch_length=1.,
|
|
|
|
+ step_length=0.1,
|
|
event_trial_length=5.):
|
|
event_trial_length=5.):
|
|
"""模型验证接口,使用指定数据进行验证,绘制ersd map
|
|
"""模型验证接口,使用指定数据进行验证,绘制ersd map
|
|
Args:
|
|
Args:
|
|
raw (mne.io.Raw)
|
|
raw (mne.io.Raw)
|
|
event_id (dict)
|
|
event_id (dict)
|
|
model: validate existing model,
|
|
model: validate existing model,
|
|
- state_change_threshold (float): default 0.8
|
|
|
|
- step_length (float): batch data step length, default 1. (s)
|
|
|
|
|
|
+ epoch_length (float): batch data length, default 1 (s)
|
|
|
|
+ step_length (float): data step length, default 0.1 (s)
|
|
event_trial_length (float):
|
|
event_trial_length (float):
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
@@ -142,16 +150,13 @@ def simulation(raw_val, event_id, model,
|
|
rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
|
|
rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'],
|
|
mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
|
|
mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'],
|
|
use_original_label=True)
|
|
use_original_label=True)
|
|
-
|
|
|
|
-
|
|
|
|
- controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
|
|
|
|
- model_hmm = controller.real_feedback_model
|
|
|
|
|
|
|
|
# run with and without hmm
|
|
# run with and without hmm
|
|
fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val,
|
|
fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val,
|
|
events_val,
|
|
events_val,
|
|
- model_hmm,
|
|
|
|
- step_length,
|
|
|
|
|
|
+ model,
|
|
|
|
+ epoch_length,
|
|
|
|
+ step_length,
|
|
event_trial_length=event_trial_length)
|
|
event_trial_length=event_trial_length)
|
|
|
|
|
|
return metric_hmm, metric_naive, fig_pred
|
|
return metric_hmm, metric_naive, fig_pred
|
|
@@ -183,6 +188,7 @@ if __name__ == '__main__':
|
|
data_dir = f'./data/{subj_name}/'
|
|
data_dir = f'./data/{subj_name}/'
|
|
|
|
|
|
model_path = f'./static/models/{subj_name}/{args.model_filename}'
|
|
model_path = f'./static/models/{subj_name}/{args.model_filename}'
|
|
|
|
+
|
|
with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
|
|
with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
|
|
info = yaml.safe_load(f)
|
|
info = yaml.safe_load(f)
|
|
sessions = info['sessions']
|
|
sessions = info['sessions']
|
|
@@ -192,18 +198,25 @@ if __name__ == '__main__':
|
|
|
|
|
|
# preprocess raw
|
|
# preprocess raw
|
|
trial_time = 5.
|
|
trial_time = 5.
|
|
- raw = neo.raw_preprocessing(data_dir, sessions,
|
|
|
|
|
|
+ raw = neo.raw_loader(data_dir, sessions,
|
|
unify_label=True,
|
|
unify_label=True,
|
|
ori_epoch_length=trial_time,
|
|
ori_epoch_length=trial_time,
|
|
mov_trial_ind=[2],
|
|
mov_trial_ind=[2],
|
|
rest_trial_ind=[1],
|
|
rest_trial_ind=[1],
|
|
upsampled_epoch_length=None)
|
|
upsampled_epoch_length=None)
|
|
|
|
+
|
|
|
|
+ # load model
|
|
|
|
+ input_kwargs = {
|
|
|
|
+ 'state_trans_prob': args.state_trans_prob,
|
|
|
|
+ 'state_change_threshold': args.state_change_threshold
|
|
|
|
+ }
|
|
|
|
+ model_hmm = online.model_loader(model_path, **input_kwargs)
|
|
|
|
|
|
# do validations
|
|
# do validations
|
|
metric_hmm, metric_naive, fig_pred = simulation(raw,
|
|
metric_hmm, metric_naive, fig_pred = simulation(raw,
|
|
event_id,
|
|
event_id,
|
|
- model=model_path,
|
|
|
|
- state_change_threshold=args.state_change_threshold,
|
|
|
|
|
|
+ model=model_hmm,
|
|
|
|
+ epoch_length=config_info['buffer_length'],
|
|
step_length=config_info['buffer_length'],
|
|
step_length=config_info['buffer_length'],
|
|
event_trial_length=trial_time)
|
|
event_trial_length=trial_time)
|
|
fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
|
|
fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
|