import argparse
import logging
import os

import mne
import numpy as np
import yaml
from sklearn.linear_model import LogisticRegression

import bci_core.pipeline as bci_pipeline
import bci_core.utils as bci_utils
from dataloaders import neo
from settings.config import settings

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO


def parse_args():
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str
    )
    parser.add_argument(
        '--model-type',
        dest='model_type',
        default='baseline',
        type=str
    )
    return parser.parse_args()
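

# Example invocation (the script filename and subject name below are
# illustrative, not taken from the repository):
#
#   python train_model.py --subj S01 --model-type riemann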


def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
    """Build the feature extractor and embedder for the requested model type,
    then fit a logistic-regression classifier on the embedded epochs."""
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    fs = raw.info['sfreq']
    n_ch = len(raw.ch_names)
    # build feature extractor and embedder for the requested model type
    if model_type.lower() == 'baseline':
        feat_extractor, embedder = bci_pipeline.baseline_model_builder(fs=fs, target_fs=10, **model_kwargs)
    elif model_type.lower() == 'riemann':
        feat_extractor, embedder = bci_pipeline.riemann_model_builder(fs=fs, n_ch=n_ch, **model_kwargs)
    elif model_type.lower() == 'csp':
        feat_extractor, embedder = bci_pipeline.csp_model_builder(fs=fs, **model_kwargs)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    # grid-search and fit the classifier on the embedded epochs
    classifier = _param_search([feat_extractor, embedder], raw, trial_duration, events)
    return [feat_extractor, embedder, classifier]


def _param_search(model, raw, duration, events):
    fs = raw.info['sfreq']
    feat_extractor, embedder = model
    # filter the continuous data, then cut it into fixed-length epochs
    filtered_data = feat_extractor.transform(raw.get_data())
    X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
    y = events[:, -1]
    # embed features
    X_embed = embedder.fit_transform(X, y)
    # grid-search the regularization strength of the logistic regression
    param = {'C': np.logspace(-5, 4, 10)}
    best_auc, best_param = bci_utils.param_search(LogisticRegression, X_embed, y, param)
    logger.info(f'Best parameter: {best_param}, best auc {best_auc}')
    # refit the best model on all epochs
    model_for_train = LogisticRegression(**best_param)
    model_for_train.fit(X_embed, y)
    return model_for_train
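
# For reference, the search above is conceptually similar to an sklearn grid
# search over the same C grid. This is only a sketch: bci_utils.param_search
# is the implementation actually used, and its scoring/CV scheme may differ.
#
#   from sklearn.model_selection import GridSearchCV
#   search = GridSearchCV(LogisticRegression(max_iter=1000),
#                         {'C': np.logspace(-5, 4, 10)},
#                         scoring='roc_auc_ovr', cv=5)
#   search.fit(X_embed, y)
#   classifier = search.best_estimator_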


if __name__ == '__main__':
    args = parse_args()
    subj_name = args.subj
    model_type = args.model_type

    data_dir = os.path.join(settings.DATA_PATH, subj_name)
    model_dir = settings.MODEL_PATH

    with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
        model_config = yaml.safe_load(f)[model_type]
    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']

    upsampled_trial_duration = config_info['buffer_length']
    ori_epoch_length = info.get('ori_epoch_length', 5.)

    # preprocess raw
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   upsampled_epoch_length=upsampled_trial_duration,
                                   ori_epoch_length=ori_epoch_length,
                                   reref_method=config_info['reref'])

    # train model
    model = train_model(raw, event_id=event_id, model_type=model_type,
                        trial_duration=upsampled_trial_duration, **model_config)

    # save
    bci_utils.model_saver(model, model_dir, model_type, subj_name, event_id)