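"""Train a 'baseline' or 'riemann' decoding model on a subject's preprocessed
NEO recordings and dump the fitted model with joblib."""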
import logging
import os
from datetime import datetime
from functools import partial

import joblib
import mne
import numpy as np
import yaml
from pyriemann.estimation import BlockCovariances
from scipy import signal

import bci_core.feature_extractors as feature_extractors
import bci_core.model as bci_model
import bci_core.utils as bci_utils
from dataloaders import neo


def train_model(raw, event_id, model_type='baseline'):
- """
- """
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    if model_type.lower() == 'baseline':
        model = _train_baseline_model(raw, events)
    elif model_type.lower() == 'riemann':
        # TODO: load subject config
        model = _train_riemann_model(raw, events)
    else:
        raise NotImplementedError(f'Unsupported model type: {model_type}')
    return model


def _train_riemann_model(raw, events, lfb_bands=((15, 30), (30, 45)), hg_bands=((55, 95), (105, 145))):
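    """Fit the Riemann pipeline: band-filter the raw data, scale per channel,
    compute block covariances, then grid-search a stacked Riemannian
    classifier over (C_lfb, C_hgs)."""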
    fs = raw.info['sfreq']
    n_ch = len(raw.ch_names)
    feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
    filtered_data = feat_extractor.transform(raw.get_data())
    # TODO: find proper latency
    X = bci_utils.cut_epochs((0, 1., fs), filtered_data, events[:, 0])
    y = events[:, -1]

    scaler = bci_model.ChannelScaler()
    X = scaler.fit_transform(X)

    # compute block covariances: one block per frequency-band group
    lfb_dim = len(lfb_bands) * n_ch
    hgs_dim = len(hg_bands) * n_ch
    cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
    X_cov = cov_model.fit_transform(X)

    param = {'C_lfb': np.logspace(-4, 0, 5), 'C_hgs': np.logspace(-3, 1, 5)}
    model_func = partial(bci_model.stacking_riemann, lfb_dim=lfb_dim, hgs_dim=hgs_dim)
    best_auc, best_param = bci_utils.param_search(model_func, X_cov, y, param)
    logging.info(f'Best parameter: {best_param}, best auc {best_auc}')

    # refit the best model on the covariance features
    # (param_search ran on X_cov, so the final fit must use X_cov as well)
    model_to_train = model_func(**best_param)
    model_to_train.fit(X_cov, y)
    return [feat_extractor, scaler, cov_model, model_to_train]


def _train_baseline_model(raw, events):
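    """Fit the baseline pipeline: filter-bank features on 1 s epochs,
    downsampled to 10 Hz, classified by a grid-searched baseline_model."""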
    fs = raw.info['sfreq']
    filter_bank_data = feature_extractors.filterbank_extractor(
        raw.get_data(), fs, np.arange(20, 150, 15), reshape_freqs_dim=True)
    filter_bank_epoch = bci_utils.cut_epochs((0, 1., fs), filter_bank_data, events[:, 0])

    # downsample to 10 Hz in two stages, each by a factor of sqrt(fs / 10)
    # (e.g. 1000 Hz -> 100 Hz -> 10 Hz); staging keeps each decimation
    # factor small enough for a stable anti-aliasing filter
    decimate_rate = int(np.sqrt(fs / 10))
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)

    X = filter_bank_epoch
    y = events[:, -1]
    best_auc, best_param = bci_utils.param_search(
        bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
    logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
    model_to_train = bci_model.baseline_model(**best_param)
    model_to_train.fit(X, y)
    return model_to_train


def model_saver(model, model_path, model_type, subject_id, event_id):
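    """Serialize the model to <model_path>/<subject_id>/ as
    '<model_type>_<classes>_<MM-DD-YYYY-HH-MM-SS>.pkl', where <classes>
    joins the event names sorted by class label."""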
    # order event names by their integer class labels
    sorted_events = sorted(event_id.items(), key=lambda item: item[1])
    sorted_events = [name for name, _ in sorted_events]

    os.makedirs(os.path.join(model_path, subject_id), exist_ok=True)

    classes = '+'.join(sorted_events)
    date_time_str = datetime.now().strftime('%m-%d-%Y-%H-%M-%S')
    model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
    joblib.dump(model, os.path.join(model_path, subject_id, model_name))


if __name__ == '__main__':
    # TODO: argparse
    subj_name = 'ylj'
    model_type = 'baseline'
    # TODO: load subject config
    data_dir = f'./data/{subj_name}/train/'
    model_dir = './static/models/'
    # yaml.safe_load expects a stream or string, not a path: open the file first
    with open(os.path.join(data_dir, 'info.yml')) as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
    event_id = {'rest': 0}
    for finger in sessions:
        event_id[finger] = neo.FINGERMODEL_IDS[finger]

    # preprocess raw
    raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False, ori_epoch_length=5)
    # train model
    model = train_model(raw, event_id=event_id, model_type=model_type)

    # save
    model_saver(model, model_dir, model_type, subj_name, event_id)