training.py

import argparse
import logging
import os

import mne
import numpy as np
import yaml
from sklearn.linear_model import LogisticRegression

import bci_core.pipeline as bci_pipeline
import bci_core.utils as bci_utils
from dataloaders import neo
from settings.config import settings

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO
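# Note: from its use below, `settings` is assumed to expose DATA_PATH, MODEL_PATH
# and CONFIG_INFO, where CONFIG_INFO provides at least 'buffer_length' and 'reref'.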


def parse_args():
    parser = argparse.ArgumentParser(
        description='Model training'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str
    )
    parser.add_argument(
        '--model-type',
        dest='model_type',
        help='Model type to train (baseline or riemann)',
        default='baseline',
        type=str
    )
    return parser.parse_args()
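
# Example invocation (the subject name is illustrative):
#   python training.py --subj S01 --model-type riemann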


def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
    """Build the feature extractor and embedder for the requested model type,
    then fit a classifier on epochs cut from the annotated raw recording.
    """
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    fs = raw.info['sfreq']
    n_ch = len(raw.ch_names)
    if model_type.lower() == 'baseline':
        feat_extractor, embedder = bci_pipeline.baseline_model_builder(fs=fs, target_fs=10, **model_kwargs)
    elif model_type.lower() == 'riemann':
        feat_extractor, embedder = bci_pipeline.riemann_model_builder(fs=fs, n_ch=n_ch, **model_kwargs)
    else:
        raise NotImplementedError(f'Unsupported model type: {model_type}')
    classifier = _param_search([feat_extractor, embedder], raw, trial_duration, events)
    return [feat_extractor, embedder, classifier]
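
# The returned model is a three-stage list: feat_extractor (applied to the
# continuous data before epoching), embedder (fit on the epoched trials), and
# the fitted LogisticRegression classifier. The exact behaviour of the first
# two stages is defined by the builders in bci_core.pipeline.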


def _param_search(model, raw, duration, events):
    fs = raw.info['sfreq']
    feat_extractor, embedder = model
    filtered_data = feat_extractor.transform(raw.get_data())
    X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
    y = events[:, -1]
    # embed features
    X_embed = embedder.fit_transform(X)
    # grid-search the regularisation strength of the logistic regression
    param = {'C': np.logspace(-5, 4, 10)}
    best_auc, best_param = bci_utils.param_search(LogisticRegression, X_embed, y, param)
    logger.info(f'Best parameter: {best_param}, best auc {best_auc}')
    # refit the best model on the full training set
    model_for_train = LogisticRegression(**best_param)
    model_for_train.fit(X_embed, y)
    return model_for_train


if __name__ == '__main__':
    args = parse_args()
    subj_name = args.subj
    model_type = args.model_type

    data_dir = os.path.join(settings.DATA_PATH, subj_name)
    model_dir = settings.MODEL_PATH

    with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
        model_config = yaml.safe_load(f)[model_type]
    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
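
    # Expected per-subject configuration (inferred from the reads above; the
    # concrete hyperparameter names are defined by the bci_core model builders):
    #   model_config.yml -- one top-level key per model type; its value is the
    #                       mapping forwarded to train_model as **model_kwargs.
    #   train_info.yml   -- must contain 'sessions'; 'ori_epoch_length' is
    #                       optional and defaults to 5 seconds.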
    upsampled_trial_duration = config_info['buffer_length']
    ori_epoch_length = info.get('ori_epoch_length', 5.)

    # preprocess raw
    raw, event_id = neo.raw_loader(data_dir, sessions,
                                   upsampled_epoch_length=upsampled_trial_duration,
                                   ori_epoch_length=ori_epoch_length,
                                   reref_method=config_info['reref'])

    # train model
    model = train_model(raw, event_id=event_id, model_type=model_type,
                        trial_duration=upsampled_trial_duration, **model_config)

    # save
    bci_utils.model_saver(model, model_dir, model_type, subj_name, event_id)