Browse Source

Feat: 增加trial duration输入参数

dk 1 year ago
parent
commit
29e32b88fc
2 changed files with 12 additions and 7 deletions
  1. 4 2
      backend/training.py
  2. 8 5
      backend/validation.py

+ 4 - 2
backend/training.py

@@ -108,6 +108,7 @@ if __name__ == '__main__':
     subj_name = 'XW01'
     model_type = 'riemann'
     # TODO: load subject config
+    # include frequency band, model_type, upsampled_trial_duration
 
     data_dir = f'./data/{subj_name}/'
     model_dir = './static/models/'
@@ -119,11 +120,12 @@ if __name__ == '__main__':
     for f in sessions.keys():
         event_id[f] = neo.FINGERMODEL_IDS[f]
     
+    trial_duration = 0.5
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=1., ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
     # train model
-    model = train_model(raw, event_id=event_id, model_type=model_type)
+    model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration)
     
     # save
     model_saver(model, model_dir, model_type, subj_name, event_id)

+ 8 - 5
backend/validation.py

@@ -19,14 +19,16 @@ logging.basicConfig(level=logging.DEBUG)
 
 
 class DataGenerator:
-    def __init__(self, fs, X):
+    def __init__(self, fs, X, epoch_step=1.):
         self.fs = int(fs)
         self.X = X
+        self.epoch_step = epoch_step
 
     def get_data_batch(self, current_index):
-        # return 1s batch
+        # return epoch_step length batch
         # create mne object
-        data = self.X[:, current_index - self.fs:current_index].copy()
+        ind = int(self.epoch_step * self.fs)
+        data = self.X[:, current_index - ind:current_index].copy()
         return self.fs, [], data
 
     def loop(self, step_size=0.1):
@@ -35,13 +37,14 @@ class DataGenerator:
             yield i / self.fs, self.get_data_batch(i)
 
 
-def validation(raw_val, event_id, model, state_change_threshold=0.8):
+def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
     """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
     Args:
         raw (mne.io.Raw)
         event_id (dict)
         model: validate existing model, 
         state_change_threshold (float): default 0.8
+        step_length (float): batch data step length, default 1. (s)
 
     Returns:
         None
@@ -65,7 +68,7 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8):
 
     # validate with the second half
     val_data = raw_val.get_data()
-    data_gen = DataGenerator(fs, val_data)
+    data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
     rets = []
     for time, data in data_gen.loop():
         cls = controller.decision(data)