123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- import logging
- import joblib
- import os
- from datetime import datetime
- import yaml
- import argparse
- import mne
- import numpy as np
- from scipy import signal
- from pyriemann.estimation import BlockCovariances
- import bci_core.feature_extractors as feature_extractors
- import bci_core.utils as bci_utils
- import bci_core.model as bci_model
- from dataloaders import neo
- from settings.config import settings
# Route INFO-level (and above) records through the root logger's default
# stderr handler for the whole process.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Project configuration mapping; the script entry point reads 'buffer_length'
# from it to decide the trial (epoch) duration.
config_info = settings.CONFIG_INFO
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before; passing a list lets
            tests and other scripts drive the parser directly.

    Returns:
        argparse.Namespace with ``subj`` (str or None) and ``model_type``
        (str, default ``'baseline'``).
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str,
    )
    parser.add_argument(
        '--model-type',
        dest='model_type',
        help="Model type to train: 'baseline' or 'riemann'",
        default='baseline',
        type=str,
    )
    return parser.parse_args(argv)
def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
    """Train a decoder of the requested type on annotated continuous data.

    Args:
        raw: MNE Raw object; its annotations are turned into events with
            ``mne.events_from_annotations``.
        event_id: Mapping from annotation description to integer class
            label, forwarded to ``mne.events_from_annotations``.
        trial_duration: Epoch length in seconds, forwarded to the trainer
            as ``duration``.
        model_type: ``'baseline'`` or ``'riemann'`` (case-insensitive).
        **model_kwargs: Extra keyword arguments forwarded to the selected
            trainer (``freqs`` for baseline; ``lf_bands``/``hg_bands`` for
            riemann).

    Returns:
        The fitted pipeline stages produced by the trainer: a
        ``(filterbank_extractor, classifier)`` tuple for ``'baseline'`` or a
        ``[feat_extractor, scaler, cov_model, classifier]`` list for
        ``'riemann'``.

    Raises:
        NotImplementedError: If ``model_type`` is not supported. The check
            runs before event extraction so a typo fails fast.
    """
    model_key = model_type.lower()
    if model_key not in ('baseline', 'riemann'):
        raise NotImplementedError(
            f"Unsupported model type {model_type!r}; expected 'baseline' or 'riemann'"
        )

    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    if model_key == 'baseline':
        return _train_baseline_model(raw, events, duration=trial_duration, **model_kwargs)
    return _train_riemann_model(raw, events, duration=trial_duration, **model_kwargs)
def _train_riemann_model(raw, events, duration=1., lf_bands=None, hg_bands=None):
    """Train a Riemannian (block-covariance) classifier.

    Pipeline: ``FeatExtractor`` band filtering -> epoching at each event ->
    ``ChannelScaler`` -> Ledoit-Wolf ``BlockCovariances`` (one block for the
    low-frequency bands, one for the high-gamma bands) -> ``riemann_model``
    whose ``C`` is selected by ``bci_utils.param_search``.

    Args:
        raw: MNE Raw object holding the continuous recording.
        events: MNE events array; column 0 gives the epoch start samples and
            the last column the class labels.
        duration: Epoch length in seconds.
        lf_bands: Low-frequency ``(low, high)`` band edges in Hz. Defaults to
            ``[(15, 35), (35, 50)]``.
        hg_bands: High-gamma ``(low, high)`` band edges in Hz. Defaults to
            ``[(55, 95), (105, 145)]``.

    Returns:
        ``[feat_extractor, scaler, cov_model, classifier]``: the fitted
        stages, in the order they are applied.
    """
    # Build fresh lists per call instead of sharing mutable default lists.
    if lf_bands is None:
        lf_bands = [(15, 35), (35, 50)]
    if hg_bands is None:
        hg_bands = [(55, 95), (105, 145)]

    fs = raw.info['sfreq']
    n_ch = len(raw.ch_names)
    feat_extractor = feature_extractors.FeatExtractor(fs, lf_bands, hg_bands)
    filtered_data = feat_extractor.transform(raw.get_data())
    X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
    y = events[:, -1]

    scaler = bci_model.ChannelScaler()
    X = scaler.fit_transform(X)

    # Block sizes assume FeatExtractor stacks all low-frequency band channels
    # first, then all high-gamma band channels, n_ch rows per band.
    # NOTE(review): confirm this layout against FeatExtractor.transform.
    lfb_dim = len(lf_bands) * n_ch
    hgs_dim = len(hg_bands) * n_ch
    cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
    X_cov = cov_model.fit_transform(X)

    param = {'C': np.logspace(-5, 4, 10)}
    best_auc, best_param = bci_utils.param_search(bci_model.riemann_model, X_cov, y, param)
    logger.info(f'Best parameter: {best_param}, best auc {best_auc}')

    # Refit on all trials with the selected hyper-parameter (saving is the
    # caller's job).
    model_to_train = bci_model.riemann_model(**best_param)
    model_to_train.fit(X_cov, y)
    return [feat_extractor, scaler, cov_model, model_to_train]
def _train_baseline_model(raw, events, duration=1., freqs=(20, 150, 15)):
    """Train the filter-bank baseline classifier.

    Pipeline: ``FilterbankExtractor`` built from ``np.arange(*freqs)`` ->
    epoching at each event -> two equal decimation stages towards ~10 Hz ->
    ``baseline_model`` whose ``C`` is selected by ``bci_utils.param_search``.

    Args:
        raw: MNE Raw object holding the continuous recording.
        events: MNE events array; column 0 gives the epoch start samples and
            the last column the class labels.
        duration: Epoch length in seconds.
        freqs: ``(start, stop, step)`` passed to ``np.arange`` to build the
            filter-bank frequency grid.

    Returns:
        ``(filterbank_extractor, classifier)``.

    Raises:
        ValueError: If the sampling rate is below 10 Hz, so no integer
            decimation factor >= 1 exists.

    NOTE(review): the decimation step is not part of the returned pipeline.
    Whatever applies this model at inference time presumably repeats it with
    the same factor; keep the two in sync.
    """
    fs = raw.info['sfreq']
    filterbank_extractor = feature_extractors.FilterbankExtractor(fs, np.arange(*freqs))
    filter_bank_data = filterbank_extractor.transform(raw.get_data())
    X = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])

    # Two equal stages of factor q = floor(sqrt(fs / 10)), giving an output
    # rate of fs / q**2. That is exactly 10 Hz only when fs / 10 is a perfect
    # square (e.g. 1000 Hz -> 100 Hz -> 10 Hz); otherwise it is slightly
    # above 10 Hz because q is truncated.
    decimate_rate = int(np.sqrt(fs / 10))
    if decimate_rate < 1:
        raise ValueError(f'Sampling rate {fs} Hz is too low to decimate towards 10 Hz')
    for _ in range(2):
        X = signal.decimate(X, decimate_rate, axis=-1, zero_phase=True)
    y = events[:, -1]

    best_auc, best_param = bci_utils.param_search(bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
    logger.info(f'Best parameter: {best_param}, best auc {best_auc}')

    # Refit on all trials with the selected hyper-parameter.
    model_to_train = bci_model.baseline_model(**best_param)
    model_to_train.fit(X, y)
    return filterbank_extractor, model_to_train
def model_saver(model, model_path, model_type, subject_id, event_id):
    """Serialise a trained model with joblib under ``<model_path>/<subject_id>/``.

    The file name is
    ``{model_type}_{class1+class2+...}_{MM-DD-YYYY-HH-MM-SS}.pkl``. The class
    names are ordered by their integer label.

    Args:
        model: Object to serialise (the trained pipeline).
        model_path: Root directory for saved models; created if missing.
        model_type: Model type string, used as the file-name prefix.
        subject_id: Subject name; a sub-directory of ``model_path``.
        event_id: Mapping from class name to integer label.

    Returns:
        The full path of the written file.
    """
    # Order class names by label value so the file name mirrors label order.
    classes = '+'.join(name for name, _ in sorted(event_id.items(), key=lambda item: item[1]))

    subject_dir = os.path.join(model_path, subject_id)
    # makedirs also creates a missing model_path; exist_ok makes reruns harmless.
    os.makedirs(subject_dir, exist_ok=True)

    date_time_str = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
    model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
    model_file = os.path.join(subject_dir, model_name)
    joblib.dump(model, model_file)
    return model_file
def main():
    """Train a model for one subject and save it under ``./static/models/``.

    Reads two files from ``./data/<subj>/``:
    ``model_config.yml``, whose section for the chosen model type supplies
    keyword arguments for the trainer, and ``train_info.yml``, whose
    ``sessions`` entry lists the recordings to load.
    """
    args = parse_args()
    subj_name = args.subj
    if subj_name is None:
        # Without a subject every path below would point at './data/None/'.
        raise SystemExit('error: --subj is required')
    model_type = args.model_type
    data_dir = f'./data/{subj_name}/'
    model_dir = './static/models/'

    with open(os.path.join(data_dir, 'model_config.yml'), 'r', encoding='utf-8') as f:
        # An empty section loads as None; treat it as "no extra kwargs".
        model_config = yaml.safe_load(f)[model_type] or {}
    with open(os.path.join(data_dir, 'train_info.yml'), 'r', encoding='utf-8') as f:
        sessions = yaml.safe_load(f)['sessions']

    trial_duration = config_info['buffer_length']
    # Load the raw recordings. raw_loader gets the original epoch length (5)
    # and the target length (trial_duration); NOTE(review): units presumably
    # seconds, confirm in dataloaders.neo.
    raw, event_id = neo.raw_loader(data_dir, sessions, upsampled_epoch_length=trial_duration, ori_epoch_length=5)
    model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)
    model_saver(model, model_dir, model_type, subj_name, event_id)


if __name__ == '__main__':
    main()
|