# training.py — offline model training entry point (scraped page header removed)
  1. import logging
  2. import os
  3. import yaml
  4. import argparse
  5. import mne
  6. import numpy as np
  7. from sklearn.linear_model import LogisticRegression
  8. import bci_core.pipeline as bci_pipeline
  9. import bci_core.utils as bci_utils
  10. from dataloaders import neo
  11. from settings.config import settings
# Configure root logging once at import time so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Shared runtime configuration (e.g. buffer length, reref method) from settings.
config_info = settings.CONFIG_INFO
  15. def parse_args():
  16. parser = argparse.ArgumentParser(
  17. description='Model validation'
  18. )
  19. parser.add_argument(
  20. '--subj',
  21. dest='subj',
  22. help='Subject name',
  23. default=None,
  24. type=str
  25. )
  26. parser.add_argument(
  27. '--model-type',
  28. dest='model_type',
  29. default='baseline',
  30. type=str
  31. )
  32. return parser.parse_args()
  33. def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
  34. """
  35. """
  36. events, _ = mne.events_from_annotations(raw, event_id=event_id)
  37. fs = raw.info['sfreq']
  38. n_ch = len(raw.ch_names)
  39. if model_type.lower() == 'baseline':
  40. feat_extractor, embedder = bci_pipeline.baseline_model_builder(fs=fs, target_fs=10, **model_kwargs)
  41. elif model_type.lower() == 'riemann':
  42. feat_extractor, embedder = bci_pipeline.riemann_model_builder(fs=fs, n_ch=n_ch, **model_kwargs)
  43. elif model_type.lower() == 'csp':
  44. feat_extractor, embedder = bci_pipeline.csp_model_builder(fs=fs, **model_kwargs)
  45. else:
  46. raise NotImplementedError
  47. classifier = _param_search([feat_extractor, embedder], raw, trial_duration, events)
  48. return [feat_extractor, embedder, classifier]
  49. def _param_search(model, raw, duration, events):
  50. fs = raw.info['sfreq']
  51. feat_extractor, embedder = model
  52. filtered_data = feat_extractor.transform(raw.get_data())
  53. X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
  54. y = events[:, -1]
  55. # embed feature
  56. X_embed = embedder.fit_transform(X, y)
  57. param = {'C': np.logspace(-5, 4, 10)}
  58. best_auc, best_param = bci_utils.param_search(LogisticRegression, X_embed, y, param)
  59. logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
  60. # train and dump best model
  61. model_for_train = LogisticRegression(**best_param)
  62. model_for_train.fit(X_embed, y)
  63. return model_for_train
  64. if __name__ == '__main__':
  65. args = parse_args()
  66. subj_name = args.subj
  67. model_type = args.model_type
  68. data_dir = os.path.join(settings.DATA_PATH, subj_name)
  69. model_dir = settings.MODEL_PATH
  70. with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
  71. model_config = yaml.safe_load(f)[model_type]
  72. with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
  73. info = yaml.safe_load(f)
  74. sessions = info['sessions']
  75. upsampled_trial_duration = config_info['buffer_length']
  76. ori_epoch_length = info.get('ori_epoch_length', 5.)
  77. # preprocess raw
  78. raw, event_id = neo.raw_loader(data_dir, sessions, upsampled_epoch_length=upsampled_trial_duration, ori_epoch_length=ori_epoch_length, reref_method=config_info['reref'])
  79. # train model
  80. model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=upsampled_trial_duration, **model_config)
  81. # save
  82. bci_utils.model_saver(model, model_dir, model_type, subj_name, event_id)