|
@@ -2,7 +2,6 @@ import logging
|
|
|
import joblib
|
|
|
import os
|
|
|
from datetime import datetime
|
|
|
-from functools import partial
|
|
|
import yaml
|
|
|
|
|
|
import mne
|
|
@@ -19,27 +18,27 @@ from dataloaders import neo
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
|
|
|
-def train_model(raw, event_id, model_type='baseline'):
|
|
|
+def train_model(raw, event_id, trial_duration=1., model_type='baseline'):
|
|
|
"""
|
|
|
"""
|
|
|
events, _ = mne.events_from_annotations(raw, event_id=event_id)
|
|
|
if model_type.lower() == 'baseline':
|
|
|
- model = _train_baseline_model(raw, events)
|
|
|
+ model = _train_baseline_model(raw, events, duration=trial_duration)
|
|
|
elif model_type.lower() == 'riemann':
|
|
|
# TODO: load subject config
|
|
|
- model = _train_riemann_model(raw, events)
|
|
|
+ model = _train_riemann_model(raw, events, duration=trial_duration)
|
|
|
else:
|
|
|
raise NotImplementedError
|
|
|
return model
|
|
|
|
|
|
|
|
|
-def _train_riemann_model(raw, events, lfb_bands=[(15, 30), [30, 45]], hg_bands=[(55, 95), (105, 145)]):
|
|
|
+def _train_riemann_model(raw, events, duration=1., lfb_bands=[(15, 35), (35, 55)], hg_bands=[(55, 95), (105, 145)]):
|
|
|
fs = raw.info['sfreq']
|
|
|
n_ch = len(raw.ch_names)
|
|
|
feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
|
|
|
filtered_data = feat_extractor.transform(raw.get_data())
|
|
|
# TODO: find proper latency
|
|
|
- X = bci_utils.cut_epochs((0, 1., fs), filtered_data, events[:, 0])
|
|
|
+ X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
|
|
|
y = events[:, -1]
|
|
|
|
|
|
scaler = bci_model.ChannelScaler()
|
|
@@ -51,24 +50,23 @@ def _train_riemann_model(raw, events, lfb_bands=[(15, 30), [30, 45]], hg_bands=[
|
|
|
cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
|
|
|
X_cov = cov_model.fit_transform(X)
|
|
|
|
|
|
- param = {'C_lfb': np.logspace(-4, 0, 5), 'C_hgs': np.logspace(-3, 1, 5)}
|
|
|
+ param = {'C': np.logspace(-5, 4, 10)}
|
|
|
|
|
|
- model_func = partial(bci_model.stacking_riemann, lfb_dim=lfb_dim, hgs_dim=hgs_dim)
|
|
|
- best_auc, best_param = bci_utils.param_search(model_func, X_cov, y, param)
|
|
|
+ best_auc, best_param = bci_utils.param_search(bci_model.riemann_model, X_cov, y, param)
|
|
|
|
|
|
logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
|
|
|
|
|
|
# train and dump best model
|
|
|
- model_to_train = model_func(**best_param)
|
|
|
+ model_to_train = bci_model.riemann_model(**best_param)
|
|
|
model_to_train.fit(X_cov, y)
|
|
|
return [feat_extractor, scaler, cov_model, model_to_train]
|
|
|
|
|
|
|
|
|
-def _train_baseline_model(raw, events):
|
|
|
+def _train_baseline_model(raw, events, duration=1., ):
|
|
|
fs = raw.info['sfreq']
|
|
|
filter_bank_data = feature_extractors.filterbank_extractor(raw.get_data(), fs, np.arange(20, 150, 15), reshape_freqs_dim=True)
|
|
|
|
|
|
- filter_bank_epoch = bci_utils.cut_epochs((0, 1., fs), filter_bank_data, events[:, 0])
|
|
|
+ filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])
|
|
|
|
|
|
# downsampling to 10 Hz
|
|
|
# decim 2 times, to 100Hz
|
|
@@ -108,7 +106,7 @@ def model_saver(model, model_path, model_type, subject_id, event_id):
|
|
|
if __name__ == '__main__':
|
|
|
# TODO: argparse
|
|
|
subj_name = 'XW01'
|
|
|
- model_type = 'baseline'
|
|
|
+ model_type = 'riemann'
|
|
|
# TODO: load subject config
|
|
|
|
|
|
data_dir = f'./data/{subj_name}/'
|