import logging
import os
from datetime import datetime
from functools import partial

import joblib
import mne
import numpy as np
import yaml
from scipy import signal
from pyriemann.estimation import BlockCovariances

import bci_core.feature_extractors as feature_extractors
import bci_core.utils as bci_utils
import bci_core.model as bci_model
from dataloaders import neo

logging.basicConfig(level=logging.INFO)


def train_model(raw, event_id, model_type='baseline'):
    """Train a decoder of the given type on annotated raw data.

    Args:
        raw (mne.io.Raw): preprocessed continuous data with event annotations.
        event_id (dict): mapping from annotation name to integer class label.
        model_type (str): 'baseline' or 'riemann'.

    Returns:
        The fitted model. For 'riemann' this is the list
        [feat_extractor, scaler, cov_model, classifier].
    """
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    if model_type.lower() == 'baseline':
        model = _train_baseline_model(raw, events)
    elif model_type.lower() == 'riemann':
        # TODO: load subject config
        model = _train_riemann_model(raw, events)
    else:
        raise NotImplementedError(f'Unsupported model type: {model_type}')
    return model


def _train_riemann_model(raw, events,
                         lfb_bands=((15, 30), (30, 45)),
                         hg_bands=((55, 95), (105, 145))):
    fs = raw.info['sfreq']
    n_ch = len(raw.ch_names)

    # band-pass the data into low-frequency-band (LFB) and high-gamma (HG) streams
    feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
    filtered_data = feat_extractor.transform(raw.get_data())

    # TODO: find proper latency
    X = bci_utils.cut_epochs((0, 1., fs), filtered_data, events[:, 0])
    y = events[:, -1]

    scaler = bci_model.ChannelScaler()
    X = scaler.fit_transform(X)

    # compute block covariances: one block per frequency-band group
    lfb_dim = len(lfb_bands) * n_ch
    hgs_dim = len(hg_bands) * n_ch
    cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
    X_cov = cov_model.fit_transform(X)

    # grid-search the regularization strengths of the stacked classifier
    param = {'C_lfb': np.logspace(-4, 0, 5), 'C_hgs': np.logspace(-3, 1, 5)}
    model_func = partial(bci_model.stacking_riemann, lfb_dim=lfb_dim, hgs_dim=hgs_dim)
    best_auc, best_param = bci_utils.param_search(model_func, X_cov, y, param)
    logging.info(f'Best parameter: {best_param}, best AUC: {best_auc}')

    # refit the best model on the full training set
    model_to_train = model_func(**best_param)
    model_to_train.fit(X_cov, y)
    return [feat_extractor, scaler, cov_model, model_to_train]


def _train_baseline_model(raw, events):
    fs = raw.info['sfreq']
    filter_bank_data = feature_extractors.filterbank_extractor(
        raw.get_data(), fs, np.arange(20, 150, 15), reshape_freqs_dim=True)

    filter_bank_epoch = bci_utils.cut_epochs((0, 1., fs), filter_bank_data, events[:, 0])

    # downsample to 10 Hz in two equal stages (scipy recommends calling
    # decimate repeatedly for factors above 13); assumes fs / 10 is a
    # perfect square, e.g. 1000 Hz -> 100 Hz -> 10 Hz
    decimate_rate = int(np.sqrt(fs / 10))
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)

    X = filter_bank_epoch
    y = events[:, -1]

    best_auc, best_param = bci_utils.param_search(
        bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
    logging.info(f'Best parameter: {best_param}, best AUC: {best_auc}')

    model_to_train = bci_model.baseline_model(**best_param)
    model_to_train.fit(X, y)
    return model_to_train


def model_saver(model, model_path, model_type, subject_id, event_id):
    # class names, sorted by class label, are joined into the file name
    sorted_events = sorted(event_id.items(), key=lambda item: item[1])
    sorted_events = [item[0] for item in sorted_events]

    os.makedirs(os.path.join(model_path, subject_id), exist_ok=True)

    classes = '+'.join(sorted_events)
    date_time_str = datetime.now().strftime('%m-%d-%Y-%H-%M-%S')
    model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
    joblib.dump(model, os.path.join(model_path, subject_id, model_name))


if __name__ == '__main__':
    # TODO: argparse
    subj_name = 'ylj'
    model_type = 'baseline'
    # TODO: load subject config
    data_dir = f'./data/{subj_name}/'
    model_dir = './static/models/'

    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']

    # build the label mapping: 'rest' is class 0; each session task gets
    # its id from neo.FINGERMODEL_IDS
    event_id = {'rest': 0}
    for task in sessions.keys():
        event_id[task] = neo.FINGERMODEL_IDS[task]

    # preprocess raw
    raw = neo.raw_preprocessing(data_dir, sessions,
                                unify_label=True,
                                upsampled_epoch_length=1.,
                                ori_epoch_length=5,
                                mov_trial_ind=[6],
                                rest_trial_ind=[7])

    # train model
    model = train_model(raw, event_id=event_id, model_type=model_type)

    # save
    model_saver(model, model_dir, model_type, subj_name, event_id)
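# --- Usage sketch (informational, not executed) ---
# A minimal example of loading a dumped model back for inference. The file
# name below is hypothetical; actual names embed the sorted class list and a
# timestamp, as built in model_saver:
#
#   import joblib
#   model = joblib.load('./static/models/ylj/baseline_<classes>_<timestamp>.pkl')
#
# 'baseline' models are a single fitted estimator; 'riemann' models are the
# list [feat_extractor, scaler, cov_model, classifier], whose stages must be
# applied to new data in that order before calling the classifier.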