"""Train a BCI decoding model for one subject and save it.

Usage::

    python <this script> --subj <subject> [--model-type baseline|riemann|csp]

Reads ``<DATA_PATH>/<subj>/model_config.yml`` (one section per model type)
and ``<DATA_PATH>/<subj>/train_info.yml`` (``sessions`` and an optional
``ori_epoch_length``), trains the model and saves it under ``MODEL_PATH``.
"""
import argparse
import logging
import os

import mne
import numpy as np
import yaml
from sklearn.linear_model import LogisticRegression

import bci_core.pipeline as bci_pipeline
import bci_core.utils as bci_utils
from dataloaders import neo
from settings.config import settings


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO


def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with ``subj`` (subject name, required) and
        ``model_type`` (defaults to ``'baseline'``).
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        # The subject name is joined into the data path; a missing value
        # used to crash later inside os.path.join with an opaque TypeError.
        required=True,
        type=str
    )
    parser.add_argument(
        '--model-type',
        dest='model_type',
        help="Model type: 'baseline', 'riemann' or 'csp' (case-insensitive)",
        default='baseline',
        type=str
    )
    return parser.parse_args()


def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
    """Build the feature pipeline for ``model_type`` and fit a classifier on ``raw``.

    Args:
        raw: MNE Raw object whose annotations mark the trials.
        event_id: mapping passed to ``mne.events_from_annotations`` to turn
            annotation descriptions into integer labels.
        trial_duration: epoch length handed to ``bci_utils.cut_epochs`` together
            with the sampling rate (presumably seconds -- confirm in bci_utils).
        model_type: ``'baseline'``, ``'riemann'`` or ``'csp'`` (case-insensitive).
        **model_kwargs: forwarded to the matching ``bci_pipeline.*_model_builder``.
            For ``'baseline'`` they may override the default ``target_fs=10``.

    Returns:
        list: ``[feat_extractor, embedder, classifier]`` where ``classifier`` is a
        fitted ``LogisticRegression``.

    Raises:
        NotImplementedError: if ``model_type`` is not one of the supported types.
    """
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    fs = raw.info['sfreq']
    n_ch = len(raw.ch_names)

    kind = model_type.lower()
    if kind == 'baseline':
        # target_fs=10 is the historical default; the model config may override it
        # (previously a config-supplied target_fs raised a duplicate-keyword TypeError).
        builder_kwargs = {'target_fs': 10, **model_kwargs}
        feat_extractor, embedder = bci_pipeline.baseline_model_builder(fs=fs, **builder_kwargs)
    elif kind == 'riemann':
        feat_extractor, embedder = bci_pipeline.riemann_model_builder(fs=fs, n_ch=n_ch, **model_kwargs)
    elif kind == 'csp':
        feat_extractor, embedder = bci_pipeline.csp_model_builder(fs=fs, **model_kwargs)
    else:
        raise NotImplementedError(
            f"Unsupported model_type {model_type!r}; expected 'baseline', 'riemann' or 'csp'"
        )

    classifier = _param_search([feat_extractor, embedder], raw, trial_duration, events)
    return [feat_extractor, embedder, classifier]


def _param_search(model, raw, duration, events):
    """Grid-search the LogisticRegression ``C`` and fit the best classifier.

    Args:
        model: ``(feat_extractor, embedder)`` pair. ``feat_extractor`` is only
            applied via ``transform`` (not fitted here); ``embedder`` is fitted
            on the epoched data and thereby mutated.
        raw: MNE Raw object providing the signal and sampling rate.
        duration: epoch length passed to ``bci_utils.cut_epochs``.
        events: MNE events array; column 0 is the onset sample and the last
            column the integer label.

    Returns:
        LogisticRegression: fitted on all embedded epochs with the best ``C``.
    """
    fs = raw.info['sfreq']
    feat_extractor, embedder = model
    filtered_data = feat_extractor.transform(raw.get_data())
    X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
    y = events[:, -1]

    # Embed features. NOTE(review): the embedder is fitted on all epochs before the
    # C search, so if param_search cross-validates, the reported AUC may be optimistic.
    X_embed = embedder.fit_transform(X, y)

    param = {'C': np.logspace(-5, 4, 10)}  # 1e-5 ... 1e4, one value per decade
    best_auc, best_param = bci_utils.param_search(LogisticRegression, X_embed, y, param)
    logger.info('Best parameter: %s, best auc %s', best_param, best_auc)

    # train and dump best model
    model_for_train = LogisticRegression(**best_param)
    model_for_train.fit(X_embed, y)
    return model_for_train


def _load_yaml(path):
    """Load ``path`` with ``yaml.safe_load``; an empty document yields ``{}``."""
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f) or {}


def main():
    """Train the model selected on the command line and save it via ``bci_utils.model_saver``."""
    args = parse_args()
    subj_name = args.subj
    model_type = args.model_type

    data_dir = os.path.join(settings.DATA_PATH, subj_name)
    model_dir = settings.MODEL_PATH

    config_path = os.path.join(data_dir, 'model_config.yml')
    model_configs = _load_yaml(config_path)
    if model_type not in model_configs:
        raise KeyError(
            f"model type {model_type!r} not found in {config_path}; "
            f"available sections: {list(model_configs)}"
        )
    # An empty section (e.g. ``baseline:`` with no keys) loads as None; treat it as no kwargs.
    model_config = model_configs[model_type] or {}

    info = _load_yaml(os.path.join(data_dir, 'train_info.yml'))
    sessions = info['sessions']
    upsampled_trial_duration = config_info['buffer_length']
    ori_epoch_length = info.get('ori_epoch_length', 5.)

    # preprocess raw
    raw, event_id = neo.raw_loader(
        data_dir,
        sessions,
        upsampled_epoch_length=upsampled_trial_duration,
        ori_epoch_length=ori_epoch_length,
        reref_method=config_info['reref'],
    )

    # train model
    model = train_model(
        raw,
        event_id=event_id,
        model_type=model_type,
        trial_duration=upsampled_trial_duration,
        **model_config,
    )

    # save
    bci_utils.model_saver(model, model_dir, model_type, subj_name, event_id)


if __name__ == '__main__':
    main()