
Feat: account for different trial lengths under the feedback condition

dk 1 year ago
parent
commit
f3dc9588c9
2 changed files with 30 additions and 13 deletions
  1. +4 −4 backend/dataloaders/neo.py
  2. +26 −9 backend/validation.py

+ 4 - 4
backend/dataloaders/neo.py

@@ -26,7 +26,7 @@ def raw_preprocessing(data_root, session_paths:dict,
         subj_root: 
         session_paths: dict of lists
         do_rereference (bool): do common average rereference or not
-        upsampled_epoch_length: 
+        upsampled_epoch_length (None or float): epoch length in seconds used to upsample events; None skips upsampling
         ori_epoch_length (int or 'varied'): original epoch length in second
         unify_label (bool): True if the original data did not use unified event labels (old pony format) and unification should be applied; False to keep the original labels
         mov_trial_ind: movement-trial indices; only used when unify_label == True, i.e. the raw file's annotations did not use unified labels (old pony format)
@@ -54,9 +54,9 @@ def raw_preprocessing(data_root, session_paths:dict,
                                     rest_trial_ind=rest_trial_ind,
                                     trial_duration=trial_duration, 
                                     use_original_label=not unify_label)
-        
-        events_upsampled = upsample_events(events, int(fs * upsampled_epoch_length))
-        annotations = mne.annotations_from_events(events_upsampled, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
+        if upsampled_epoch_length is not None:
+            events = upsample_events(events, int(fs * upsampled_epoch_length))
+        annotations = mne.annotations_from_events(events, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
         raw.set_annotations(annotations)
         raws.append(raw)
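
A minimal sketch of what the now-optional upsampling step amounts to. The repository's `upsample_events` is not part of this diff, so the helper below is an illustrative assumption (MNE-style `(sample, 0, code)` events, each trial lasting until the next onset), not the actual implementation; it only shows why passing `upsampled_epoch_length=None` keeps the original, possibly varied, trial boundaries instead of splitting every trial into fixed-length epochs.

```python
import numpy as np

def upsample_events_sketch(events: np.ndarray, epoch_samples: int) -> np.ndarray:
    """Split each trial into consecutive fixed-length epochs (illustration only)."""
    out = []
    for i, (onset, _, code) in enumerate(events):
        # trial end: next event onset, or one epoch after the last onset
        end = events[i + 1, 0] if i + 1 < len(events) else onset + epoch_samples
        out.extend([start, 0, code] for start in range(onset, end, epoch_samples))
    return np.array(out)

# 1000 Hz example: a 5 s movement trial, a 3 s rest trial, a final 1 s trial
events = np.array([[0, 0, 1], [5000, 0, 2], [8000, 0, 1]])
print(upsample_events_sketch(events, epoch_samples=1000).shape)  # (9, 3)
```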
 

+ 26 - 9
backend/validation.py

@@ -71,7 +71,7 @@ class DataGenerator:
             yield i / self.fs, self.get_data_batch(i)
 
 
-def _evaluation_loop(raw, events, model_hmm, step_length):
+def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
     val_data = raw.get_data()
     fs = raw.info['sfreq']
 
@@ -96,7 +96,7 @@ def _evaluation_loop(raw, events, model_hmm, step_length):
 
     p_n, r_n, f1_n = bci_utils.event_metric(events, events_pred_naive, fs=fs)
 
-    stim_true = _event_to_stim_channel(events, len(raw.times))
+    stim_true = _event_to_stim_channel(events, len(raw.times), trial_length=int(event_trial_length * fs))
     stim_pred = _event_to_stim_channel(events_pred, len(raw.times))
     stim_pred_naive = _event_to_stim_channel(events_pred_naive, len(raw.times))
 
@@ -118,7 +118,10 @@ def _evaluation_loop(raw, events, model_hmm, step_length):
     return fig_pred, (p_hmm, r_hmm, f1_hmm, accu_hmm), (p_n, r_n, f1_n, accu_naive)
 
 
-def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
+def validation(raw_val, event_id, model, 
+               state_change_threshold=0.8, 
+               step_length=1., 
+               event_trial_length=5.):
     """模型验证接口,使用指定数据进行验证,绘制ersd map
     Args:
         raw (mne.io.Raw)
@@ -126,6 +129,7 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length
         model: validate existing model, 
         state_change_threshold (float): default 0.8
         step_length (float): batch data step length, default 1. (s)
+        event_trial_length (float): ground-truth trial duration in seconds, default 5.
 
     Returns:
         None
@@ -149,7 +153,11 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length
     model_hmm = controller.real_feedback_model
 
     # run with and without hmm
-    fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, events_val, model_hmm, step_length)
+    fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, 
+                                                          events_val, 
+                                                          model_hmm, 
+                                                          step_length, 
+                                                          event_trial_length=event_trial_length)
 
     return metric_hmm, metric_naive, fig_erds, fig_pred
 
@@ -163,10 +171,13 @@ def _construct_model_event(decision_seq, fs):
     return np.array(events)
 
 
-def _event_to_stim_channel(events, time_length):
+def _event_to_stim_channel(events, time_length, trial_length=None):
     x = np.zeros(time_length)
     for i in range(0, len(events) - 1):
-        x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
+        if trial_length is not None:
+            x[events[i, 0]: events[i, 0] + trial_length] = events[i, 2]
+        else:
+            x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
     return x
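
To see what the new `trial_length` argument changes, the function added above can be run standalone on a toy event array (the 10 Hz sampling rate and label values below are made up for illustration):

```python
import numpy as np

def _event_to_stim_channel(events, time_length, trial_length=None):
    x = np.zeros(time_length)
    for i in range(0, len(events) - 1):  # note: the last event is never written
        if trial_length is not None:
            # fixed trial length: mark exactly trial_length samples per event
            x[events[i, 0]: events[i, 0] + trial_length] = events[i, 2]
        else:
            # previous behaviour: fill up to one sample before the next onset
            x[events[i, 0]: events[i + 1, 0] - 1] = events[i, 2]
    return x

events = np.array([[10, 0, 1], [50, 0, 2], [80, 0, 1]])      # (sample, 0, label) at fs = 10 Hz
full  = _event_to_stim_channel(events, 100)                   # labels run until the next onset
fixed = _event_to_stim_channel(events, 100, trial_length=20)  # labels cover exactly 2 s each
print(np.count_nonzero(full), np.count_nonzero(fixed))        # 68 vs 40
```

This mirrors how `_evaluation_loop` now builds `stim_true` with `trial_length=int(event_trial_length * fs)` while the predicted stim channels keep the old gap-filling behaviour.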
 
 
@@ -185,15 +196,21 @@ if __name__ == '__main__':
         event_id[f] = neo.FINGERMODEL_IDS[f]
     
     # preprocess raw
-    trial_duration = config_info['buffer_length']
-    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1], upsampled_epoch_length=trial_duration)
+    trial_time = 5.
+    raw = neo.raw_preprocessing(data_dir, sessions, 
+                                unify_label=True, 
+                                ori_epoch_length=trial_time, 
+                                mov_trial_ind=[2], 
+                                rest_trial_ind=[1], 
+                                upsampled_epoch_length=None)
 
     # do validations
     metric_hmm, metric_naive, fig_erds, fig_pred = validation(raw, 
                                              event_id, 
                                              model=model_path, 
                                              state_change_threshold=args.state_change_threshold,
-                                             step_length=config_info['buffer_length'])
+                                             step_length=config_info['buffer_length'],
+                                             event_trial_length=trial_time)
     fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   
     logger.info(f'With HMM: precision: {metric_hmm[0]:.4f}, recall: {metric_hmm[1]:.4f}, f1_score: {metric_hmm[2]:.4f}, accuracy: {metric_hmm[3]:.4f}')