|
@@ -53,12 +53,12 @@ class HMMClassifier(hmm.BaseHMM):
|
|
return p
|
|
return p
|
|
|
|
|
|
|
|
|
|
-def extract_baseline_feature(model, raw, step):
|
|
|
|
|
|
+def extract_baseline_feature(model, raw, step=0.1, buffer_length=0.5):
|
|
fs = raw.info['sfreq']
|
|
fs = raw.info['sfreq']
|
|
feat_extractor, _ = model
|
|
feat_extractor, _ = model
|
|
filter_bank_data = feat_extractor.transform(raw.get_data())
|
|
filter_bank_data = feat_extractor.transform(raw.get_data())
|
|
- timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
|
|
|
|
- filter_bank_epoch = bci_utils.cut_epochs((0, step, fs), filter_bank_data, timestamps)
|
|
|
|
|
|
+ timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs, buffer_length)
|
|
|
|
+ filter_bank_epoch = bci_utils.cut_epochs((0, buffer_length, fs), filter_bank_data, timestamps)
|
|
# decimate
|
|
# decimate
|
|
decimate_rate = np.sqrt(fs / 10).astype(np.int16)
|
|
decimate_rate = np.sqrt(fs / 10).astype(np.int16)
|
|
filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
|
|
filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
|
|
@@ -67,20 +67,20 @@ def extract_baseline_feature(model, raw, step):
|
|
return filter_bank_epoch
|
|
return filter_bank_epoch
|
|
|
|
|
|
|
|
|
|
-def extract_riemann_feature(model, raw, step):
|
|
|
|
|
|
+def extract_riemann_feature(model, raw, step=0.1, buffer_length=0.5):
|
|
fs = raw.info['sfreq']
|
|
fs = raw.info['sfreq']
|
|
feat_extractor, scaler, cov_model, _ = model
|
|
feat_extractor, scaler, cov_model, _ = model
|
|
filtered_data = feat_extractor.transform(raw.get_data())
|
|
filtered_data = feat_extractor.transform(raw.get_data())
|
|
- timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
|
|
|
|
- X = bci_utils.cut_epochs((0, step, fs), filtered_data, timestamps)
|
|
|
|
|
|
+ timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs, buffer_length)
|
|
|
|
+ X = bci_utils.cut_epochs((0, buffer_length, fs), filtered_data, timestamps)
|
|
X = scaler.transform(X)
|
|
X = scaler.transform(X)
|
|
X_cov = cov_model.transform(X)
|
|
X_cov = cov_model.transform(X)
|
|
return X_cov
|
|
return X_cov
|
|
|
|
|
|
|
|
|
|
-def _split_continuous(time_range, step, fs):
|
|
|
|
|
|
+def _split_continuous(time_range, step, fs, window_size):
|
|
return np.arange(int(time_range[0] * fs),
|
|
return np.arange(int(time_range[0] * fs),
|
|
- int(time_range[-1] * fs),
|
|
|
|
|
|
+ int(time_range[-1] * fs) - int(window_size * fs),
|
|
int(step * fs), dtype=np.int64)
|
|
int(step * fs), dtype=np.int64)
|
|
|
|
|
|
|
|
|
|
@@ -143,9 +143,9 @@ if __name__ == '__main__':
|
|
|
|
|
|
# cut into buffer len epochs
|
|
# cut into buffer len epochs
|
|
if model_type == 'baseline':
|
|
if model_type == 'baseline':
|
|
- feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
|
|
|
|
|
|
+ feature = extract_baseline_feature(model, raw, step=0.1, buffer_length=config_info['buffer_length'])
|
|
elif model_type == 'riemann':
|
|
elif model_type == 'riemann':
|
|
- feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
|
|
|
|
|
|
+ feature = extract_riemann_feature(model, raw, step=0.1, buffer_length=config_info['buffer_length'])
|
|
else:
|
|
else:
|
|
raise ValueError
|
|
raise ValueError
|
|
|
|
|