train_hmm.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
'''
Use a trained classifier as the emission model and train the HMM transition matrix on free grasping tasks.
'''
  4. import os
  5. import argparse
  6. from hmmlearn import hmm
  7. import numpy as np
  8. import yaml
  9. import joblib
  10. from scipy import signal
  11. import matplotlib.pyplot as plt
  12. from dataloaders import neo
  13. import bci_core.utils as bci_utils
  14. from settings.config import settings
# Project configuration taken from settings; this file reads its
# 'reref' and 'buffer_length' entries.
config_info = settings.CONFIG_INFO
  16. class HMMClassifier(hmm.BaseHMM):
  17. def __init__(self, emission_model, **kwargs):
  18. n_components = len(emission_model.classes_)
  19. super(HMMClassifier, self).__init__(n_components=n_components, params='t', init_params='st', **kwargs)
  20. self.emission_model = emission_model
  21. def _check_and_set_n_features(self, X):
  22. if X.ndim == 2: #
  23. n_features = X.shape[1]
  24. elif X.ndim == 3:
  25. n_features = X.shape[1] * X.shape[2]
  26. else:
  27. raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')
  28. if hasattr(self, "n_features"):
  29. if self.n_features != n_features:
  30. raise ValueError(
  31. f"Unexpected number of dimensions, got {n_features} but "
  32. f"expected {self.n_features}")
  33. else:
  34. self.n_features = n_features
  35. def _get_n_fit_scalars_per_param(self):
  36. nc = self.n_components
  37. return {
  38. "s": nc,
  39. "t": nc ** 2}
  40. def _compute_likelihood(self, X):
  41. p = self.emission_model.predict_proba(X)
  42. return p
  43. def extract_embedded_feature(model, raw, step=0.1, buffer_length=0.5):
  44. fs = raw.info['sfreq']
  45. feat_extractor, embedder, _ = model
  46. filtered_data = feat_extractor.transform(raw.get_data())
  47. timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs, buffer_length)
  48. X = bci_utils.cut_epochs((0, buffer_length, fs), filtered_data, timestamps)
  49. X_embed = embedder.transform(X)
  50. return X_embed
  51. def _split_continuous(time_range, step, fs, window_size):
  52. return np.arange(int(time_range[0] * fs),
  53. int(time_range[-1] * fs) - int(window_size * fs),
  54. int(step * fs), dtype=np.int64)
  55. def parse_args():
  56. parser = argparse.ArgumentParser(
  57. description='Model validation'
  58. )
  59. parser.add_argument(
  60. '--subj',
  61. dest='subj',
  62. help='Subject name',
  63. default=None,
  64. type=str
  65. )
  66. parser.add_argument(
  67. '--model-filename',
  68. dest='model_filename',
  69. help='Model filename',
  70. default=None,
  71. type=str
  72. )
  73. return parser.parse_args()
  74. if __name__ == '__main__':
  75. args = parse_args()
  76. # load model and fit hmm
  77. subj_name = args.subj
  78. model_filename = args.model_filename
  79. data_dir = os.path.join(settings.DATA_PATH, subj_name)
  80. model_path = os.path.join(settings.MODEL_PATH, subj_name, model_filename)
  81. # load model
  82. model = joblib.load(model_path)
  83. with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
  84. info = yaml.safe_load(f)
  85. sessions = info['hmm_sessions']
  86. raw, event_id = neo.raw_loader(data_dir, sessions, config_info['reref'])
  87. # cut into buffer len epochs
  88. feature = extract_embedded_feature(model, raw, step=0.1, buffer_length=config_info['buffer_length'])
  89. # initiate hmm model
  90. # TODO: building transmat init
  91. hmm_model = HMMClassifier(model[-1], n_iter=100)
  92. hmm_model.fit(feature)
  93. # decode
  94. log_probs, state_seqs = hmm_model.decode(feature)
  95. plt.figure()
  96. plt.plot(state_seqs)
  97. # save transmat
  98. np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
  99. plt.show()