# training.py
# Train a decoding model for one subject and save it under ./static/models/.
  1. import logging
  2. import joblib
  3. import os
  4. from datetime import datetime
  5. import yaml
  6. import argparse
  7. import mne
  8. import numpy as np
  9. from scipy import signal
  10. from pyriemann.estimation import BlockCovariances
  11. import bci_core.feature_extractors as feature_extractors
  12. import bci_core.utils as bci_utils
  13. import bci_core.model as bci_model
  14. from dataloaders import neo
  15. from settings.config import settings
# Configure the root logger so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Project configuration object; the script entry point indexes it with the
# string keys 'buffer_length' and 'reref'.
config_info = settings.CONFIG_INFO
  19. def parse_args():
  20. parser = argparse.ArgumentParser(
  21. description='Model validation'
  22. )
  23. parser.add_argument(
  24. '--subj',
  25. dest='subj',
  26. help='Subject name',
  27. default=None,
  28. type=str
  29. )
  30. parser.add_argument(
  31. '--model-type',
  32. dest='model_type',
  33. default='baseline',
  34. type=str
  35. )
  36. return parser.parse_args()
  37. def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
  38. """
  39. """
  40. events, _ = mne.events_from_annotations(raw, event_id=event_id)
  41. if model_type.lower() == 'baseline':
  42. model = _train_baseline_model(raw, events, duration=trial_duration, **model_kwargs)
  43. elif model_type.lower() == 'riemann':
  44. model = _train_riemann_model(raw, events, duration=trial_duration, **model_kwargs)
  45. else:
  46. raise NotImplementedError
  47. return model
  48. def _train_riemann_model(raw, events, duration=1., lf_bands=[(15, 35), (35, 50)], hg_bands=[(55, 95), (105, 145)]):
  49. fs = raw.info['sfreq']
  50. n_ch = len(raw.ch_names)
  51. feat_extractor = feature_extractors.FeatExtractor(fs, lf_bands, hg_bands)
  52. filtered_data = feat_extractor.transform(raw.get_data())
  53. X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
  54. y = events[:, -1]
  55. scaler = bci_model.ChannelScaler()
  56. X = scaler.fit_transform(X)
  57. # compute covariance
  58. feat_dim = []
  59. if lf_bands is not None:
  60. feat_dim.append(len(lf_bands) * n_ch)
  61. if hg_bands is not None:
  62. feat_dim.append(len(hg_bands) * n_ch)
  63. cov_model = BlockCovariances(feat_dim, estimator='lwf')
  64. X_cov = cov_model.fit_transform(X)
  65. param = {'C': np.logspace(-5, 4, 10)}
  66. best_auc, best_param = bci_utils.param_search(bci_model.riemann_model, X_cov, y, param)
  67. logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
  68. # train and dump best model
  69. model_to_train = bci_model.riemann_model(**best_param)
  70. model_to_train.fit(X_cov, y)
  71. return [feat_extractor, scaler, cov_model, model_to_train]
  72. def _train_baseline_model(raw, events, duration=1., freqs=(20, 150, 15)):
  73. fs = raw.info['sfreq']
  74. freqs = np.arange(*freqs)
  75. filterbank_extractor = feature_extractors.FilterbankExtractor(fs, freqs)
  76. filter_bank_data = filterbank_extractor.transform(raw.get_data())
  77. filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])
  78. # downsampling to 10 Hz
  79. # decim 2 times, to 100Hz
  80. decimate_rate = np.sqrt(fs / 10).astype(np.int16)
  81. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  82. # to 10Hz
  83. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  84. X = filter_bank_epoch
  85. y = events[:, -1]
  86. best_auc, best_param = bci_utils.param_search(bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
  87. logger.info(f'Best parameter: {best_param}, best auc {best_auc}')
  88. model_to_train = bci_model.baseline_model(**best_param)
  89. model_to_train.fit(X, y)
  90. return filterbank_extractor, model_to_train
  91. def model_saver(model, model_path, model_type, subject_id, event_id):
  92. # event list should be sorted by class label
  93. sorted_events = sorted(event_id.items(), key=lambda item: item[1])
  94. # Extract the keys in the sorted order and store them in a list
  95. sorted_events = [item[0] for item in sorted_events]
  96. try:
  97. os.mkdir(os.path.join(model_path, subject_id))
  98. except FileExistsError:
  99. pass
  100. now = datetime.now()
  101. classes = '+'.join(sorted_events)
  102. date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
  103. model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
  104. joblib.dump(model, os.path.join(model_path, subject_id, model_name))
  105. if __name__ == '__main__':
  106. args = parse_args()
  107. subj_name = args.subj
  108. model_type = args.model_type
  109. data_dir = f'./data/{subj_name}/'
  110. model_dir = './static/models/'
  111. with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
  112. model_config = yaml.safe_load(f)[model_type]
  113. with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
  114. info = yaml.safe_load(f)
  115. sessions = info['sessions']
  116. trial_duration = config_info['buffer_length']
  117. # preprocess raw
  118. raw, event_id = neo.raw_loader(data_dir, sessions, upsampled_epoch_length=trial_duration, ori_epoch_length=5, reref_method=config_info['reref'])
  119. # train model
  120. model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)
  121. # save
  122. model_saver(model, model_dir, model_type, subj_name, event_id)