1
0

training.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import logging
  2. import joblib
  3. import os
  4. from datetime import datetime
  5. from functools import partial
  6. import yaml
  7. import mne
  8. import numpy as np
  9. from scipy import signal
  10. from pyriemann.estimation import BlockCovariances
  11. import bci_core.feature_extractors as feature_extractors
  12. import bci_core.utils as bci_utils
  13. import bci_core.model as bci_model
  14. from dataloaders import neo
  15. logging.basicConfig(level=logging.INFO)
  16. def train_model(raw, event_id, model_type='baseline'):
  17. """
  18. """
  19. events, _ = mne.events_from_annotations(raw, event_id=event_id)
  20. if model_type.lower() == 'baseline':
  21. model = _train_baseline_model(raw, events)
  22. elif model_type.lower() == 'riemann':
  23. # TODO: load subject config
  24. model = _train_riemann_model(raw, events)
  25. else:
  26. raise NotImplementedError
  27. return model
  28. def _train_riemann_model(raw, events, lfb_bands=[(15, 30), [30, 45]], hg_bands=[(55, 95), (105, 145)]):
  29. fs = raw.info['sfreq']
  30. n_ch = len(raw.ch_names)
  31. feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
  32. filtered_data = feat_extractor.transform(raw.get_data())
  33. # TODO: find proper latency
  34. X = bci_utils.cut_epochs((0, 1., fs), filtered_data, events[:, 0])
  35. y = events[:, -1]
  36. scaler = bci_model.ChannelScaler()
  37. X = scaler.fit_transform(X)
  38. # compute covariance
  39. lfb_dim = len(lfb_bands) * n_ch
  40. hgs_dim = len(hg_bands) * n_ch
  41. cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
  42. X_cov = cov_model.fit_transform(X)
  43. param = {'C_lfb': np.logspace(-4, 0, 5), 'C_hgs': np.logspace(-3, 1, 5)}
  44. model_func = partial(bci_model.stacking_riemann, lfb_dim=lfb_dim, hgs_dim=hgs_dim)
  45. best_auc, best_param = bci_utils.param_search(model_func, X_cov, y, param)
  46. logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
  47. # train and dump best model
  48. model_to_train = model_func(**best_param)
  49. model_to_train.fit(X_cov, y)
  50. return [feat_extractor, scaler, cov_model, model_to_train]
  51. def _train_baseline_model(raw, events):
  52. fs = raw.info['sfreq']
  53. filter_bank_data = feature_extractors.filterbank_extractor(raw.get_data(), fs, np.arange(20, 150, 15), reshape_freqs_dim=True)
  54. filter_bank_epoch = bci_utils.cut_epochs((0, 1., fs), filter_bank_data, events[:, 0])
  55. # downsampling to 10 Hz
  56. # decim 2 times, to 100Hz
  57. decimate_rate = np.sqrt(fs / 10).astype(np.int16)
  58. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  59. # to 10Hz
  60. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  61. X = filter_bank_epoch
  62. y = events[:, -1]
  63. best_auc, best_param = bci_utils.param_search(bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
  64. logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
  65. model_to_train = bci_model.baseline_model(**best_param)
  66. model_to_train.fit(X, y)
  67. return model_to_train
  68. def model_saver(model, model_path, model_type, subject_id, event_id):
  69. # event list should be sorted by class label
  70. sorted_events = sorted(event_id.items(), key=lambda item: item[1])
  71. # Extract the keys in the sorted order and store them in a list
  72. sorted_events = [item[0] for item in sorted_events]
  73. try:
  74. os.mkdir(os.path.join(model_path, subject_id))
  75. except FileExistsError:
  76. pass
  77. now = datetime.now()
  78. classes = '+'.join(sorted_events)
  79. date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
  80. model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
  81. joblib.dump(model, os.path.join(model_path, subject_id, model_name))
  82. if __name__ == '__main__':
  83. # TODO: argparse
  84. subj_name = 'ylj'
  85. model_type = 'baseline'
  86. # TODO: load subject config
  87. data_dir = f'./data/{subj_name}/'
  88. model_dir = './static/models/'
  89. with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
  90. info = yaml.safe_load(f)
  91. sessions = info['sessions']
  92. event_id = {'rest': 0}
  93. for f in sessions.keys():
  94. event_id[f] = neo.FINGERMODEL_IDS[f]
  95. # preprocess raw
  96. raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=1., ori_epoch_length=5, mov_trial_ind=[6], rest_trial_ind=[7])
  97. # train model
  98. model = train_model(raw, event_id=event_id, model_type=model_type)
  99. # save
  100. model_saver(model, model_dir, model_type, subj_name, event_id)