training.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import logging
  2. import joblib
  3. import os
  4. from datetime import datetime
  5. import yaml
  6. import mne
  7. import numpy as np
  8. from scipy import signal
  9. from pyriemann.estimation import BlockCovariances
  10. import bci_core.feature_extractors as feature_extractors
  11. import bci_core.utils as bci_utils
  12. import bci_core.model as bci_model
  13. from dataloaders import neo
  14. from settings.config import settings
  15. logging.basicConfig(level=logging.INFO)
  16. config_info = settings.CONFIG_INFO
  17. def train_model(raw, event_id, trial_duration=1., model_type='baseline'):
  18. """
  19. """
  20. events, _ = mne.events_from_annotations(raw, event_id=event_id)
  21. if model_type.lower() == 'baseline':
  22. model = _train_baseline_model(raw, events, duration=trial_duration)
  23. elif model_type.lower() == 'riemann':
  24. # TODO: load subject config
  25. model = _train_riemann_model(raw, events, duration=trial_duration)
  26. else:
  27. raise NotImplementedError
  28. return model
  29. def _train_riemann_model(raw, events, duration=1., lfb_bands=[(15, 35), (35, 55)], hg_bands=[(55, 95), (105, 145)]):
  30. fs = raw.info['sfreq']
  31. n_ch = len(raw.ch_names)
  32. feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
  33. filtered_data = feat_extractor.transform(raw.get_data())
  34. # TODO: find proper latency
  35. X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
  36. y = events[:, -1]
  37. scaler = bci_model.ChannelScaler()
  38. X = scaler.fit_transform(X)
  39. # compute covariance
  40. lfb_dim = len(lfb_bands) * n_ch
  41. hgs_dim = len(hg_bands) * n_ch
  42. cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
  43. X_cov = cov_model.fit_transform(X)
  44. param = {'C': np.logspace(-5, 4, 10)}
  45. best_auc, best_param = bci_utils.param_search(bci_model.riemann_model, X_cov, y, param)
  46. logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
  47. # train and dump best model
  48. model_to_train = bci_model.riemann_model(**best_param)
  49. model_to_train.fit(X_cov, y)
  50. return [feat_extractor, scaler, cov_model, model_to_train]
  51. def _train_baseline_model(raw, events, duration=1., ):
  52. fs = raw.info['sfreq']
  53. filter_bank_data = feature_extractors.filterbank_extractor(raw.get_data(), fs, np.arange(20, 150, 15), reshape_freqs_dim=True)
  54. filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])
  55. # downsampling to 10 Hz
  56. # decim 2 times, to 100Hz
  57. decimate_rate = np.sqrt(fs / 10).astype(np.int16)
  58. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  59. # to 10Hz
  60. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  61. X = filter_bank_epoch
  62. y = events[:, -1]
  63. best_auc, best_param = bci_utils.param_search(bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
  64. logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
  65. model_to_train = bci_model.baseline_model(**best_param)
  66. model_to_train.fit(X, y)
  67. return model_to_train
  68. def model_saver(model, model_path, model_type, subject_id, event_id):
  69. # event list should be sorted by class label
  70. sorted_events = sorted(event_id.items(), key=lambda item: item[1])
  71. # Extract the keys in the sorted order and store them in a list
  72. sorted_events = [item[0] for item in sorted_events]
  73. try:
  74. os.mkdir(os.path.join(model_path, subject_id))
  75. except FileExistsError:
  76. pass
  77. now = datetime.now()
  78. classes = '+'.join(sorted_events)
  79. date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
  80. model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
  81. joblib.dump(model, os.path.join(model_path, subject_id, model_name))
  82. if __name__ == '__main__':
  83. # TODO: argparse
  84. subj_name = 'XW01'
  85. model_type = 'riemann'
  86. # TODO: load subject config
  87. # include frequency band, model_type, upsampled_trial_duration
  88. data_dir = f'./data/{subj_name}/'
  89. model_dir = './static/models/'
  90. with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
  91. info = yaml.safe_load(f)
  92. sessions = info['sessions']
  93. event_id = {'rest': 0}
  94. for f in sessions.keys():
  95. event_id[f] = neo.FINGERMODEL_IDS[f]
  96. trial_duration = config_info['buffer_length']
  97. # preprocess raw
  98. raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
  99. # train model
  100. model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration)
  101. # save
  102. model_saver(model, model_dir, model_type, subj_name, event_id)