1
0

training.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import logging
  2. import joblib
  3. import os
  4. from datetime import datetime
  5. from functools import partial
  6. import yaml
  7. import mne
  8. import numpy as np
  9. from scipy import signal
  10. from pyriemann.estimation import BlockCovariances
  11. import bci_core.feature_extractors as feature_extractors
  12. import bci_core.utils as bci_utils
  13. import bci_core.model as bci_model
  14. from dataloaders import neo
  15. def train_model(raw, event_id, model_type='baseline'):
  16. """
  17. """
  18. events, _ = mne.events_from_annotations(raw, event_id=event_id)
  19. if model_type.lower() == 'baseline':
  20. model = _train_baseline_model(raw, events)
  21. elif model_type.lower() == 'riemann':
  22. # TODO: load subject config
  23. model = _train_riemann_model(raw, events)
  24. else:
  25. raise NotImplementedError
  26. return model
  27. def _train_riemann_model(raw, events, lfb_bands=[(15, 30), [30, 45]], hg_bands=[(55, 95), (105, 145)]):
  28. fs = raw.info['sfreq']
  29. n_ch = len(raw.ch_names)
  30. feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
  31. filtered_data = feat_extractor.transform(raw.get_data())
  32. # TODO: find proper latency
  33. X = bci_utils.cut_epochs((0, 1., fs), filtered_data, events[:, 0])
  34. y = events[:, -1]
  35. scaler = bci_model.ChannelScaler()
  36. X = scaler.fit_transform(X)
  37. # compute covariance
  38. lfb_dim = len(lfb_bands) * n_ch
  39. hgs_dim = len(hg_bands) * n_ch
  40. cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
  41. X_cov = cov_model.fit_transform(X)
  42. param = {'C_lfb': np.logspace(-4, 0, 5), 'C_hgs': np.logspace(-3, 1, 5)}
  43. model_func = partial(bci_model.stacking_riemann, lfb_dim=lfb_dim, hgs_dim=hgs_dim)
  44. best_auc, best_param = bci_utils.param_search(model_func, X_cov, y, param)
  45. logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
  46. # train and dump best model
  47. model_to_train = model_func(**best_param)
  48. model_to_train.fit(X, y)
  49. return [feat_extractor, scaler, cov_model, model_to_train]
  50. def _train_baseline_model(raw, events):
  51. fs = raw.info['sfreq']
  52. filter_bank_data = feature_extractors.filterbank_extractor(raw.get_data(), fs, np.arange(20, 150, 15), reshape_freqs_dim=True)
  53. filter_bank_epoch = bci_utils.cut_epochs((0, 1., fs), filter_bank_data, events[:, 0])
  54. # downsampling to 10 Hz
  55. # decim 2 times, to 100Hz
  56. decimate_rate = np.sqrt(fs / 10).astype(np.int16)
  57. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  58. # to 10Hz
  59. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  60. X = filter_bank_epoch
  61. y = events[:, -1]
  62. best_auc, best_param = bci_utils.param_search(bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
  63. logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
  64. model_to_train = bci_model.baseline_model(**best_param)
  65. model_to_train.fit(X, y)
  66. return model_to_train
  67. def model_saver(model, model_path, model_type, subject_id, event_id):
  68. # event list should be sorted by class label
  69. sorted_events = sorted(event_id.items(), key=lambda item: item[1])
  70. # Extract the keys in the sorted order and store them in a list
  71. sorted_events = [item[0] for item in sorted_events]
  72. try:
  73. os.mkdir(os.path.join(model_path, subject_id))
  74. except FileExistsError:
  75. pass
  76. now = datetime.now()
  77. classes = '+'.join(sorted_events)
  78. date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
  79. model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
  80. joblib.dump(model, os.path.join(model_path, subject_id, model_name))
  81. if __name__ == '__main__':
  82. # TODO: argparse
  83. subj_name = 'ylj'
  84. model_type = 'baseline'
  85. # TODO: load subject config
  86. data_dir = f'./data/{subj_name}/train/'
  87. model_dir = './static/models/'
  88. info = yaml.safe_load(os.path.join(data_dir, 'info.yml'))
  89. sessions = info['sessions']
  90. event_id = {'rest': 0}
  91. for f in sessions.keys():
  92. event_id[f] = neo.FINGERMODEL_IDS[f]
  93. # preprocess raw
  94. raw = neo.raw_preprocessing(data_dir, sessions, rename_event=False, ori_epoch_length=5)
  95. # train model
  96. model = train_model(raw, event_id=event_id, model_type=model_type)
  97. # save
  98. model_saver(model, model_dir, model_type, subj_name, event_id)