2
0

train_hmm.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
'''
Use a trained classifier as the emission model and train the HMM transition matrix on free-grasping task data.
'''
  4. import os
  5. import argparse
  6. from hmmlearn import hmm
  7. import numpy as np
  8. import yaml
  9. import joblib
  10. from scipy import signal
  11. import matplotlib.pyplot as plt
  12. from dataloaders import neo
  13. import bci_core.utils as bci_utils
  14. from settings.config import settings
# Runtime configuration; this script reads config_info['buffer_length'] and
# uses it as the window length/hop (in seconds) when cutting the recording.
config_info = settings.CONFIG_INFO
  16. class HMMClassifier(hmm.BaseHMM):
  17. # TODO: 如何绕过hmmlearn里使用的sklearn.utils.validation.check_array,目前我直接修改了hmmlearn的源码(删除所有的check_array)
  18. # TODO: 可行的方法是修改模型组织,将特征提取步骤与最终分类器分开,模型只保留最终分类器,这样仅需接收二维特征。
  19. def __init__(self, emission_model, **kwargs):
  20. n_components = len(emission_model.classes_)
  21. super(HMMClassifier, self).__init__(n_components=n_components, params='t', init_params='st', **kwargs)
  22. self.emission_model = emission_model
  23. def _check_and_set_n_features(self, X):
  24. if X.ndim == 2: #
  25. n_features = X.shape[1]
  26. elif X.ndim == 3:
  27. n_features = X.shape[1] * X.shape[2]
  28. else:
  29. raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')
  30. if hasattr(self, "n_features"):
  31. if self.n_features != n_features:
  32. raise ValueError(
  33. f"Unexpected number of dimensions, got {n_features} but "
  34. f"expected {self.n_features}")
  35. else:
  36. self.n_features = n_features
  37. def _get_n_fit_scalars_per_param(self):
  38. nc = self.n_components
  39. return {
  40. "s": nc,
  41. "t": nc ** 2}
  42. def _compute_likelihood(self, X):
  43. p = self.emission_model.predict_proba(X)
  44. return p
  45. def extract_baseline_feature(model, raw, step):
  46. fs = raw.info['sfreq']
  47. feat_extractor, _ = model
  48. filter_bank_data = feat_extractor.transform(raw.get_data())
  49. timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
  50. filter_bank_epoch = bci_utils.cut_epochs((0, step, fs), filter_bank_data, timestamps)
  51. # decimate
  52. decimate_rate = np.sqrt(fs / 10).astype(np.int16)
  53. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  54. # to 10Hz
  55. filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
  56. return filter_bank_epoch
  57. def extract_riemann_feature(model, raw, step):
  58. fs = raw.info['sfreq']
  59. feat_extractor, scaler, cov_model, _ = model
  60. filtered_data = feat_extractor.transform(raw.get_data())
  61. timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
  62. X = bci_utils.cut_epochs((0, step, fs), filtered_data, timestamps)
  63. X = scaler.transform(X)
  64. X_cov = cov_model.transform(X)
  65. return X_cov
  66. def _split_continuous(time_range, step, fs):
  67. return np.arange(int(time_range[0] * fs),
  68. int(time_range[-1] * fs),
  69. int(step * fs), dtype=np.int64)
  70. def parse_args():
  71. parser = argparse.ArgumentParser(
  72. description='Model validation'
  73. )
  74. parser.add_argument(
  75. '--subj',
  76. dest='subj',
  77. help='Subject name',
  78. default=None,
  79. type=str
  80. )
  81. parser.add_argument(
  82. '--state-change-threshold',
  83. '-scth',
  84. dest='state_change_threshold',
  85. help='Threshold for HMM state change',
  86. default=0.75,
  87. type=float
  88. )
  89. parser.add_argument(
  90. '--state-trans-prob',
  91. '-stp',
  92. dest='state_trans_prob',
  93. help='Transition probability for HMM state change',
  94. default=0.8,
  95. type=float
  96. )
  97. parser.add_argument(
  98. '--model-filename',
  99. dest='model_filename',
  100. help='Model filename',
  101. default=None,
  102. type=str
  103. )
  104. return parser.parse_args()
  105. args = parse_args()
  106. # load model and fit hmm
  107. subj_name = args.subj
  108. model_filename = args.model_filename
  109. data_dir = f'./data/{subj_name}/'
  110. model_path = f'./static/models/{subj_name}/{model_filename}'
  111. # load model
  112. model_type, _ = bci_utils.parse_model_type(model_filename)
  113. model = joblib.load(model_path)
  114. with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
  115. info = yaml.safe_load(f)
  116. sessions = info['hmm_sessions']
  117. raw, event_id = neo.raw_loader(data_dir, sessions, True)
  118. # cut into buffer len epochs
  119. if model_type == 'baseline':
  120. feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
  121. elif model_type == 'riemann':
  122. feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
  123. else:
  124. raise ValueError
  125. # initiate hmm model
  126. hmm_model = HMMClassifier(model[-1], n_iter=100)
  127. hmm_model.fit(feature)
  128. # decode
  129. log_probs, state_seqs = hmm_model.decode(feature)
  130. plt.figure()
  131. plt.plot(state_seqs)
  132. # save transmat
  133. np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
  134. plt.show()