"""Offline training entry point: fit a baseline or Riemannian BCI decoder
from recorded sessions and dump the fitted pipeline to disk."""
import argparse
import logging
import os
from datetime import datetime

import joblib
import mne
import numpy as np
import yaml
from pyriemann.estimation import BlockCovariances
from scipy import signal

import bci_core.feature_extractors as feature_extractors
import bci_core.model as bci_model
import bci_core.utils as bci_utils
from dataloaders import neo
from settings.config import settings

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO


def parse_args():
    """Parse command-line arguments: subject name and model type."""
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str
    )
    parser.add_argument(
        '--model-type',
        dest='model_type',
        default='baseline',
        type=str
    )
    return parser.parse_args()


def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
    """Train a decoding model on annotated raw data.

    Args:
        raw: mne.io.Raw with event annotations.
        event_id: dict mapping annotation description -> integer class label.
        trial_duration: epoch length in seconds, measured from event onset.
        model_type: 'baseline' or 'riemann' (case-insensitive).
        **model_kwargs: forwarded to the type-specific trainer.

    Returns:
        The fitted pipeline components, as returned by the selected trainer.

    Raises:
        NotImplementedError: for an unknown model_type.
    """
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    if model_type.lower() == 'baseline':
        model = _train_baseline_model(raw, events, duration=trial_duration, **model_kwargs)
    elif model_type.lower() == 'riemann':
        model = _train_riemann_model(raw, events, duration=trial_duration, **model_kwargs)
    else:
        raise NotImplementedError
    return model


def _train_riemann_model(raw, events, duration=1.,
                         lf_bands=((15, 35), (35, 50)),
                         hg_bands=((55, 95), (105, 145))):
    """Fit the Riemannian pipeline: band filtering -> epoching -> channel
    scaling -> block covariances -> hyperparameter-searched classifier.

    lf_bands / hg_bands are (low, high) Hz filter bands for the
    low-frequency and high-gamma feature groups; defaults are immutable
    tuples to avoid the shared-mutable-default pitfall.

    Returns:
        list: [feat_extractor, scaler, cov_model, fitted classifier].
    """
    fs = raw.info['sfreq']
    n_ch = len(raw.ch_names)
    feat_extractor = feature_extractors.FeatExtractor(fs, lf_bands, hg_bands)
    filtered_data = feat_extractor.transform(raw.get_data())
    X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
    y = events[:, -1]

    scaler = bci_model.ChannelScaler()
    X = scaler.fit_transform(X)

    # Covariance blocks: one block per feature group (bands x channels).
    lfb_dim = len(lf_bands) * n_ch
    hgs_dim = len(hg_bands) * n_ch
    cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
    X_cov = cov_model.fit_transform(X)

    param = {'C': np.logspace(-5, 4, 10)}
    best_auc, best_param = bci_utils.param_search(bci_model.riemann_model, X_cov, y, param)
    logger.info(f'Best parameter: {best_param}, best auc {best_auc}')

    # Refit the best model on the full training set.
    model_to_train = bci_model.riemann_model(**best_param)
    model_to_train.fit(X_cov, y)
    return [feat_extractor, scaler, cov_model, model_to_train]


def _train_baseline_model(raw, events, duration=1., freqs=(20, 150, 15)):
    """Fit the baseline pipeline: filterbank features -> epoching ->
    two-stage decimation to 10 Hz -> hyperparameter-searched classifier.

    Args:
        freqs: (start, stop, step) in Hz, expanded with np.arange.

    Returns:
        tuple: (filterbank_extractor, fitted classifier).
    """
    fs = raw.info['sfreq']
    freqs = np.arange(*freqs)
    filterbank_extractor = feature_extractors.FilterbankExtractor(fs, freqs)
    filter_bank_data = filterbank_extractor.transform(raw.get_data())

    filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])

    # Downsample to 10 Hz in two equal decimation stages (e.g. 1000 Hz
    # -> 100 Hz -> 10 Hz); a single large factor would need a steeper
    # anti-aliasing filter. Assumes fs / 10 is a perfect square.
    decimate_rate = np.sqrt(fs / 10).astype(np.int16)
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)

    X = filter_bank_epoch
    y = events[:, -1]
    best_auc, best_param = bci_utils.param_search(
        bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
    logger.info(f'Best parameter: {best_param}, best auc {best_auc}')

    # Refit the best model on the full training set.
    model_to_train = bci_model.baseline_model(**best_param)
    model_to_train.fit(X, y)
    return filterbank_extractor, model_to_train


def model_saver(model, model_path, model_type, subject_id, event_id):
    """Dump a fitted model to <model_path>/<subject_id>/, naming the file
    with the model type, class labels (sorted by label value), and a
    timestamp."""
    # Event names sorted by class label so the filename order is stable.
    sorted_events = sorted(event_id.items(), key=lambda item: item[1])
    sorted_events = [item[0] for item in sorted_events]

    os.makedirs(os.path.join(model_path, subject_id), exist_ok=True)

    now = datetime.now()
    classes = '+'.join(sorted_events)
    date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
    model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
    joblib.dump(model, os.path.join(model_path, subject_id, model_name))


if __name__ == '__main__':
    args = parse_args()

    subj_name = args.subj
    model_type = args.model_type

    data_dir = f'./data/{subj_name}/'
    model_dir = './static/models/'

    # Per-subject model hyperparameters and session list.
    with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
        model_config = yaml.safe_load(f)[model_type]
    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']

    trial_duration = config_info['buffer_length']

    # Preprocess raw recordings into a single annotated Raw object.
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   upsampled_epoch_length=trial_duration,
                                   ori_epoch_length=5)

    # Train and persist the model.
    model = train_model(raw, event_id=event_id, model_type=model_type,
                        trial_duration=trial_duration, **model_config)
    model_saver(model, model_dir, model_type, subj_name, event_id)