train_hmm.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. '''
  2. Use a trained classifier as the emission model and train the HMM transition matrix on free grasping tasks
  3. '''
  4. import os
  5. import argparse
  6. from hmmlearn import hmm
  7. import numpy as np
  8. import yaml
  9. import joblib
  10. from scipy import signal
  11. import matplotlib.pyplot as plt
  12. from dataloaders import neo
  13. import bci_core.utils as bci_utils
  14. from settings.config import settings
  15. config_info = settings.CONFIG_INFO
  16. class HMMClassifier(hmm.BaseHMM):
  17. # TODO: how to bypass sklearn.check_array, currently I modified the src of hmmlearn (remove all the check_array)
  18. def __init__(self, emission_model, **kwargs):
  19. n_components = len(emission_model.classes_)
  20. super(HMMClassifier, self).__init__(n_components=n_components, params='t', init_params='st', **kwargs)
  21. self.emission_model = emission_model
  22. def _check_and_set_n_features(self, X):
  23. if X.ndim == 2: #
  24. n_features = X.shape[1]
  25. elif X.ndim == 3:
  26. n_features = X.shape[1] * X.shape[2]
  27. else:
  28. raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')
  29. if hasattr(self, "n_features"):
  30. if self.n_features != n_features:
  31. raise ValueError(
  32. f"Unexpected number of dimensions, got {n_features} but "
  33. f"expected {self.n_features}")
  34. else:
  35. self.n_features = n_features
  36. def _get_n_fit_scalars_per_param(self):
  37. nc = self.n_components
  38. return {
  39. "s": nc,
  40. "t": nc ** 2}
  41. def _compute_likelihood(self, X):
  42. p = self.emission_model.predict_proba(X)
  43. return p
  44. def extract_baseline_feature(model, raw, step):
  45. fs = raw.info['sfreq']
  46. feat_extractor, _ = model
  47. filter_bank_data = feat_extractor.transform(raw.get_data())
  48. timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
  49. filter_bank_epoch = bci_utils.cut_epochs((0, step, fs), filter_bank_data, timestamps)
  50. # decimate
  51. decimate_rate = np.sqrt(fs / 10).astype(np.int16)
  52. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  53. # to 10Hz
  54. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  55. return filter_bank_epoch
  56. def extract_riemann_feature(model, raw, step):
  57. fs = raw.info['sfreq']
  58. feat_extractor, scaler, cov_model, _ = model
  59. filtered_data = feat_extractor.transform(raw.get_data())
  60. timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
  61. X = bci_utils.cut_epochs((0, step, fs), filtered_data, timestamps)
  62. X = scaler.transform(X)
  63. X_cov = cov_model.transform(X)
  64. return X_cov
  65. def _split_continuous(time_range, step, fs):
  66. return np.arange(int(time_range[0] * fs),
  67. int(time_range[-1] * fs),
  68. int(step * fs), dtype=np.int64)
  69. def parse_args():
  70. parser = argparse.ArgumentParser(
  71. description='Model validation'
  72. )
  73. parser.add_argument(
  74. '--subj',
  75. dest='subj',
  76. help='Subject name',
  77. default=None,
  78. type=str
  79. )
  80. parser.add_argument(
  81. '--state-change-threshold',
  82. '-scth',
  83. dest='state_change_threshold',
  84. help='Threshold for HMM state change',
  85. default=0.75,
  86. type=float
  87. )
  88. parser.add_argument(
  89. '--state-trans-prob',
  90. '-stp',
  91. dest='state_trans_prob',
  92. help='Transition probability for HMM state change',
  93. default=0.8,
  94. type=float
  95. )
  96. parser.add_argument(
  97. '--model-filename',
  98. dest='model_filename',
  99. help='Model filename',
  100. default=None,
  101. type=str
  102. )
  103. return parser.parse_args()
  104. args = parse_args()
  105. # load model and fit hmm
  106. subj_name = args.subj
  107. model_filename = args.model_filename
  108. data_dir = f'./data/{subj_name}/'
  109. model_path = f'./static/models/{subj_name}/{model_filename}'
  110. # load model
  111. model_type, _ = bci_utils.parse_model_type(model_filename)
  112. model = joblib.load(model_path)
  113. with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
  114. info = yaml.safe_load(f)
  115. sessions = info['hmm_sessions']
  116. raw = neo.raw_loader(data_dir, sessions, True)
  117. # cut into buffer len epochs
  118. if model_type == 'baseline':
  119. feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
  120. elif model_type == 'riemann':
  121. feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
  122. else:
  123. raise ValueError
  124. # initiate hmm model
  125. hmm_model = HMMClassifier(model[-1], n_iter=100)
  126. hmm_model.fit(feature)
  127. # decode
  128. log_probs, state_seqs = hmm_model.decode(feature)
  129. plt.figure()
  130. plt.plot(state_seqs)
  131. # save transmat
  132. np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
  133. plt.show()