瀏覽代碼

HMM训练,step和window length分别处理

dk 1 年之前
父節點
當前提交
1203d25b6e
共有 1 個文件被更改,包括 10 次插入10 次删除
  1. 10 10
      backend/train_hmm.py

+ 10 - 10
backend/train_hmm.py

@@ -53,12 +53,12 @@ class HMMClassifier(hmm.BaseHMM):
         return p
 
 
-def extract_baseline_feature(model, raw, step):
+def extract_baseline_feature(model, raw, step=0.1, buffer_length=0.5):
     fs = raw.info['sfreq']
     feat_extractor, _ = model
     filter_bank_data = feat_extractor.transform(raw.get_data())
-    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
-    filter_bank_epoch = bci_utils.cut_epochs((0, step, fs), filter_bank_data, timestamps)
+    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs, buffer_length)
+    filter_bank_epoch = bci_utils.cut_epochs((0, buffer_length, fs), filter_bank_data, timestamps)
     # decimate
     decimate_rate = np.sqrt(fs / 10).astype(np.int16)
     filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
@@ -67,20 +67,20 @@ def extract_baseline_feature(model, raw, step):
     return filter_bank_epoch
 
 
-def extract_riemann_feature(model, raw, step):
+def extract_riemann_feature(model, raw, step=0.1, buffer_length=0.5):
     fs = raw.info['sfreq']
     feat_extractor, scaler, cov_model, _ = model
     filtered_data = feat_extractor.transform(raw.get_data())
-    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
-    X = bci_utils.cut_epochs((0, step, fs), filtered_data, timestamps)
+    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs, buffer_length)
+    X = bci_utils.cut_epochs((0, buffer_length, fs), filtered_data, timestamps)
     X = scaler.transform(X)
     X_cov = cov_model.transform(X)
     return X_cov
 
 
-def _split_continuous(time_range, step, fs):
+def _split_continuous(time_range, step, fs, window_size):
     return np.arange(int(time_range[0] * fs), 
-                           int(time_range[-1] * fs), 
+                           int(time_range[-1] * fs) - int(window_size * fs), 
                            int(step * fs), dtype=np.int64)
 
 
@@ -143,9 +143,9 @@ if __name__ == '__main__':
 
     # cut into buffer len epochs
     if model_type == 'baseline':
-        feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
+        feature = extract_baseline_feature(model, raw, step=0.1, buffer_length=config_info['buffer_length'])
     elif model_type == 'riemann':
-        feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
+        feature = extract_riemann_feature(model, raw, step=0.1, buffer_length=config_info['buffer_length'])
     else:
         raise ValueError