training.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import logging
  2. import joblib
  3. import os
  4. from datetime import datetime
  5. import yaml
  6. import argparse
  7. import mne
  8. import numpy as np
  9. from scipy import signal
  10. from pyriemann.estimation import BlockCovariances
  11. import bci_core.feature_extractors as feature_extractors
  12. import bci_core.utils as bci_utils
  13. import bci_core.model as bci_model
  14. from dataloaders import neo
  15. from settings.config import settings
  16. logging.basicConfig(level=logging.INFO)
  17. config_info = settings.CONFIG_INFO
  18. def parse_args():
  19. parser = argparse.ArgumentParser(
  20. description='Model validation'
  21. )
  22. parser.add_argument(
  23. '--subj',
  24. dest='subj',
  25. help='Subject name',
  26. default=None,
  27. type=str
  28. )
  29. parser.add_argument(
  30. '--model-type',
  31. dest='model_type',
  32. default='baseline',
  33. type=str
  34. )
  35. return parser.parse_args()
  36. def train_model(raw, event_id, trial_duration=1., model_type='baseline'):
  37. """
  38. """
  39. events, _ = mne.events_from_annotations(raw, event_id=event_id)
  40. if model_type.lower() == 'baseline':
  41. model = _train_baseline_model(raw, events, duration=trial_duration)
  42. elif model_type.lower() == 'riemann':
  43. # TODO: load subject config
  44. model = _train_riemann_model(raw, events, duration=trial_duration)
  45. else:
  46. raise NotImplementedError
  47. return model
  48. def _train_riemann_model(raw, events, duration=1., lfb_bands=[(15, 35), (35, 55)], hg_bands=[(55, 95), (105, 145)]):
  49. fs = raw.info['sfreq']
  50. n_ch = len(raw.ch_names)
  51. feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
  52. filtered_data = feat_extractor.transform(raw.get_data())
  53. # TODO: find proper latency
  54. X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
  55. y = events[:, -1]
  56. scaler = bci_model.ChannelScaler()
  57. X = scaler.fit_transform(X)
  58. # compute covariance
  59. lfb_dim = len(lfb_bands) * n_ch
  60. hgs_dim = len(hg_bands) * n_ch
  61. cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
  62. X_cov = cov_model.fit_transform(X)
  63. param = {'C': np.logspace(-5, 4, 10)}
  64. best_auc, best_param = bci_utils.param_search(bci_model.riemann_model, X_cov, y, param)
  65. logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
  66. # train and dump best model
  67. model_to_train = bci_model.riemann_model(**best_param)
  68. model_to_train.fit(X_cov, y)
  69. return [feat_extractor, scaler, cov_model, model_to_train]
  70. def _train_baseline_model(raw, events, duration=1., ):
  71. fs = raw.info['sfreq']
  72. filter_bank_data = feature_extractors.filterbank_extractor(raw.get_data(), fs, np.arange(20, 150, 15), reshape_freqs_dim=True)
  73. filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])
  74. # downsampling to 10 Hz
  75. # decim 2 times, to 100Hz
  76. decimate_rate = np.sqrt(fs / 10).astype(np.int16)
  77. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  78. # to 10Hz
  79. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  80. X = filter_bank_epoch
  81. y = events[:, -1]
  82. best_auc, best_param = bci_utils.param_search(bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
  83. logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
  84. model_to_train = bci_model.baseline_model(**best_param)
  85. model_to_train.fit(X, y)
  86. return model_to_train
  87. def model_saver(model, model_path, model_type, subject_id, event_id):
  88. # event list should be sorted by class label
  89. sorted_events = sorted(event_id.items(), key=lambda item: item[1])
  90. # Extract the keys in the sorted order and store them in a list
  91. sorted_events = [item[0] for item in sorted_events]
  92. try:
  93. os.mkdir(os.path.join(model_path, subject_id))
  94. except FileExistsError:
  95. pass
  96. now = datetime.now()
  97. classes = '+'.join(sorted_events)
  98. date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
  99. model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
  100. joblib.dump(model, os.path.join(model_path, subject_id, model_name))
  101. if __name__ == '__main__':
  102. args = parse_args()
  103. subj_name = args.subj
  104. model_type = args.model_type
  105. # TODO: load subject config
  106. # include frequency band, model_type, upsampled_trial_duration
  107. data_dir = f'./data/{subj_name}/'
  108. model_dir = './static/models/'
  109. with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
  110. info = yaml.safe_load(f)
  111. sessions = info['sessions']
  112. event_id = {'rest': 0}
  113. for f in sessions.keys():
  114. event_id[f] = neo.FINGERMODEL_IDS[f]
  115. trial_duration = config_info['buffer_length']
  116. # preprocess raw
  117. raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
  118. # train model
  119. model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration)
  120. # save
  121. model_saver(model, model_dir, model_type, subj_name, event_id)