training.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import logging
  2. import joblib
  3. import os
  4. from datetime import datetime
  5. import yaml
  6. import argparse
  7. import mne
  8. import numpy as np
  9. from scipy import signal
  10. from pyriemann.estimation import BlockCovariances
  11. import bci_core.feature_extractors as feature_extractors
  12. import bci_core.utils as bci_utils
  13. import bci_core.model as bci_model
  14. from dataloaders import neo
  15. from settings.config import settings
# Module-level logging setup: root config + a named logger for this script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global runtime configuration (e.g. buffer_length) loaded from project settings.
config_info = settings.CONFIG_INFO
  19. def parse_args():
  20. parser = argparse.ArgumentParser(
  21. description='Model validation'
  22. )
  23. parser.add_argument(
  24. '--subj',
  25. dest='subj',
  26. help='Subject name',
  27. default=None,
  28. type=str
  29. )
  30. parser.add_argument(
  31. '--model-type',
  32. dest='model_type',
  33. default='baseline',
  34. type=str
  35. )
  36. return parser.parse_args()
  37. def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
  38. """
  39. """
  40. events, _ = mne.events_from_annotations(raw, event_id=event_id)
  41. if model_type.lower() == 'baseline':
  42. model = _train_baseline_model(raw, events, duration=trial_duration, **model_kwargs)
  43. elif model_type.lower() == 'riemann':
  44. model = _train_riemann_model(raw, events, duration=trial_duration, **model_kwargs)
  45. else:
  46. raise NotImplementedError
  47. return model
  48. def _train_riemann_model(raw, events, duration=1., lf_bands=[(15, 35), (35, 50)], hg_bands=[(55, 95), (105, 145)]):
  49. fs = raw.info['sfreq']
  50. n_ch = len(raw.ch_names)
  51. feat_extractor = feature_extractors.FeatExtractor(fs, lf_bands, hg_bands)
  52. filtered_data = feat_extractor.transform(raw.get_data())
  53. X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
  54. y = events[:, -1]
  55. scaler = bci_model.ChannelScaler()
  56. X = scaler.fit_transform(X)
  57. # compute covariance
  58. lfb_dim = len(lf_bands) * n_ch
  59. hgs_dim = len(hg_bands) * n_ch
  60. cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
  61. X_cov = cov_model.fit_transform(X)
  62. param = {'C': np.logspace(-5, 4, 10)}
  63. best_auc, best_param = bci_utils.param_search(bci_model.riemann_model, X_cov, y, param)
  64. logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
  65. # train and dump best model
  66. model_to_train = bci_model.riemann_model(**best_param)
  67. model_to_train.fit(X_cov, y)
  68. return [feat_extractor, scaler, cov_model, model_to_train]
  69. def _train_baseline_model(raw, events, duration=1., freqs=(20, 150, 15)):
  70. fs = raw.info['sfreq']
  71. freqs = np.arange(*freqs)
  72. filterbank_extractor = feature_extractors.FilterbankExtractor(fs, freqs)
  73. filter_bank_data = filterbank_extractor.transform(raw.get_data())
  74. filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])
  75. # downsampling to 10 Hz
  76. # decim 2 times, to 100Hz
  77. decimate_rate = np.sqrt(fs / 10).astype(np.int16)
  78. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  79. # to 10Hz
  80. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  81. X = filter_bank_epoch
  82. y = events[:, -1]
  83. best_auc, best_param = bci_utils.param_search(bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
  84. logger.info(f'Best parameter: {best_param}, best auc {best_auc}')
  85. model_to_train = bci_model.baseline_model(**best_param)
  86. model_to_train.fit(X, y)
  87. return filterbank_extractor, model_to_train
  88. def model_saver(model, model_path, model_type, subject_id, event_id):
  89. # event list should be sorted by class label
  90. sorted_events = sorted(event_id.items(), key=lambda item: item[1])
  91. # Extract the keys in the sorted order and store them in a list
  92. sorted_events = [item[0] for item in sorted_events]
  93. try:
  94. os.mkdir(os.path.join(model_path, subject_id))
  95. except FileExistsError:
  96. pass
  97. now = datetime.now()
  98. classes = '+'.join(sorted_events)
  99. date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
  100. model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
  101. joblib.dump(model, os.path.join(model_path, subject_id, model_name))
if __name__ == '__main__':
    # Entry point: load per-subject config, train, and persist the model.
    args = parse_args()
    subj_name = args.subj
    model_type = args.model_type
    data_dir = f'./data/{subj_name}/'
    model_dir = './static/models/'
    # Hyperparameters for the selected model type (keyed by model_type).
    with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
        model_config = yaml.safe_load(f)[model_type]
    # Which recorded sessions to include in training.
    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
    # Trial length is tied to the online buffer length so training epochs
    # match what the deployed decoder will see.
    trial_duration = config_info['buffer_length']
    # preprocess raw
    raw, event_id = neo.raw_loader(data_dir, sessions, upsampled_epoch_length=trial_duration, ori_epoch_length=5)
    # train model
    model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)
    # save
    model_saver(model, model_dir, model_type, subj_name, event_id)