training.py

import logging
import os
from datetime import datetime

import joblib
import yaml
import mne
import numpy as np
from scipy import signal
from pyriemann.estimation import BlockCovariances

import bci_core.feature_extractors as feature_extractors
import bci_core.utils as bci_utils
import bci_core.model as bci_model
from dataloaders import neo

logging.basicConfig(level=logging.INFO)


def train_model(raw, event_id, trial_duration=1., model_type='baseline'):
    """Train a decoding model on annotated raw data.

    model_type selects between the filter-bank 'baseline' model (returns a single
    fitted classifier) and the 'riemann' model (returns the full preprocessing
    chain together with the classifier).
    """
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    if model_type.lower() == 'baseline':
        model = _train_baseline_model(raw, events, duration=trial_duration)
    elif model_type.lower() == 'riemann':
        # TODO: load subject config
        model = _train_riemann_model(raw, events, duration=trial_duration)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')
    return model


def _train_riemann_model(raw, events, duration=1.,
                         lfb_bands=[(15, 35), (35, 55)],
                         hg_bands=[(55, 95), (105, 145)]):
    fs = raw.info['sfreq']
    n_ch = len(raw.ch_names)
    # band-filter the continuous data into low-frequency-band (LFB) and high-gamma (HG) components
    feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
    filtered_data = feat_extractor.transform(raw.get_data())
    # TODO: find proper latency
    X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
    y = events[:, -1]
    scaler = bci_model.ChannelScaler()
    X = scaler.fit_transform(X)
    # compute block covariance matrices: one block for the LFB channels, one for the HG channels
    lfb_dim = len(lfb_bands) * n_ch
    hg_dim = len(hg_bands) * n_ch
    cov_model = BlockCovariances([lfb_dim, hg_dim], estimator='lwf')
    X_cov = cov_model.fit_transform(X)
    # grid-search the regularization parameter C
    param = {'C': np.logspace(-5, 4, 10)}
    best_auc, best_param = bci_utils.param_search(bci_model.riemann_model, X_cov, y, param)
    logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
    # refit the best model on all data and return the full pipeline
    model_to_train = bci_model.riemann_model(**best_param)
    model_to_train.fit(X_cov, y)
    return [feat_extractor, scaler, cov_model, model_to_train]
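

# Illustrative sketch only (not called anywhere in this script): how the pipeline list
# returned by _train_riemann_model could be applied to new continuous data, mirroring the
# training transforms above. FeatExtractor.transform, cut_epochs, and
# BlockCovariances.transform are used as in training; ChannelScaler.transform and the
# classifier's predict() are assumed to exist with sklearn-style signatures.
def _apply_riemann_pipeline(pipeline, data, events, duration, fs):
    feat_extractor, scaler, cov_model, clf = pipeline
    filtered = feat_extractor.transform(data)
    X = bci_utils.cut_epochs((0, duration, fs), filtered, events[:, 0])
    X = scaler.transform(X)          # assumption: ChannelScaler exposes transform()
    X_cov = cov_model.transform(X)
    return clf.predict(X_cov)        # assumption: sklearn-style predict()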


def _train_baseline_model(raw, events, duration=1.):
    fs = raw.info['sfreq']
    # extract filter-bank features across 20-150 Hz
    filter_bank_data = feature_extractors.filterbank_extractor(
        raw.get_data(), fs, np.arange(20, 150, 15), reshape_freqs_dim=True)
    filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])
    # downsample to 10 Hz in two decimation passes
    decimate_rate = np.sqrt(fs / 10).astype(np.int16)
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
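    # The two passes above each divide the sampling rate by decimate_rate, so together
    # they divide it by decimate_rate**2 = fs / 10, landing at 10 Hz. For example, with
    # fs = 1000 Hz (an illustrative value, not taken from the data), decimate_rate = 10
    # and the rate goes 1000 Hz -> 100 Hz -> 10 Hz.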
    X = filter_bank_epoch
    y = events[:, -1]
    # grid-search the regularization parameter C
    best_auc, best_param = bci_utils.param_search(
        bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
    logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
    model_to_train = bci_model.baseline_model(**best_param)
    model_to_train.fit(X, y)
    return model_to_train


def model_saver(model, model_path, model_type, subject_id, event_id):
    # event list should be sorted by class label
    sorted_events = sorted(event_id.items(), key=lambda item: item[1])
    # extract the keys in the sorted order and store them in a list
    sorted_events = [item[0] for item in sorted_events]
    try:
        os.mkdir(os.path.join(model_path, subject_id))
    except FileExistsError:
        pass
    now = datetime.now()
    classes = '+'.join(sorted_events)
    date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
    model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
    joblib.dump(model, os.path.join(model_path, subject_id, model_name))
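    # The resulting file path is
    #   <model_path>/<subject_id>/<model_type>_<class0+class1+...>_<MM-DD-YYYY-HH-MM-SS>.pkl
    # with class names ordered by their numeric labels in event_id.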


if __name__ == '__main__':
    # TODO: argparse
    subj_name = 'XW01'
    model_type = 'riemann'
    # TODO: load subject config
    data_dir = f'./data/{subj_name}/'
    model_dir = './static/models/'

    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
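    # train_info.yml is expected to provide a 'sessions' mapping whose keys are movement
    # labels known to neo.FINGERMODEL_IDS; the values are passed through unchanged to
    # neo.raw_preprocessing. Illustrative shape only (actual keys and values depend on
    # the recording protocol and the neo dataloader):
    #
    #   sessions:
    #     <movement_label>: <session spec consumed by neo.raw_preprocessing>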
    # build event_id, keeping 'rest' as class 0 and taking movement labels from neo.FINGERMODEL_IDS
    event_id = {'rest': 0}
    for f in sessions.keys():
        event_id[f] = neo.FINGERMODEL_IDS[f]

    # preprocess raw
    raw = neo.raw_preprocessing(data_dir, sessions,
                                unify_label=True,
                                upsampled_epoch_length=1.,
                                ori_epoch_length=5,
                                mov_trial_ind=[2],
                                rest_trial_ind=[1])
    # train model
    model = train_model(raw, event_id=event_id, model_type=model_type)
    # save
    model_saver(model, model_dir, model_type, subj_name, event_id)