'''
Use trained classifier as emission model, train HMM transfer matrix on free grasping tasks
'''
import argparse
import os

import joblib
import matplotlib.pyplot as plt
import numpy as np
import yaml
from hmmlearn import hmm
from scipy import signal

import bci_core.utils as bci_utils
from dataloaders import neo
from settings.config import settings

config_info = settings.CONFIG_INFO


class HMMClassifier(hmm.BaseHMM):
    """HMM whose emission probabilities are produced by a pre-trained classifier.

    The number of hidden states equals the number of classes of
    ``emission_model``. The model is built with ``params='t'`` and
    ``init_params='st'``, i.e. only the transition matrix is selected for
    training, while start probabilities and the transition matrix are the
    parameters selected for initialisation.
    """
    # TODO: find a way to bypass sklearn.utils.validation.check_array, which
    #  hmmlearn calls internally. For now the hmmlearn source was patched
    #  directly (all check_array calls removed).
    # TODO: a feasible approach is to reorganise the model so that feature
    #  extraction is separated from the final classifier; the model would keep
    #  only the final classifier and therefore only receive 2-D features.

    def __init__(self, emission_model, **kwargs):
        """Create the HMM around ``emission_model``.

        Args:
            emission_model: fitted classifier exposing ``classes_`` and
                ``predict_proba``.
            **kwargs: forwarded unchanged to ``hmm.BaseHMM.__init__``
                (e.g. ``n_iter``).
        """
        n_components = len(emission_model.classes_)
        super().__init__(n_components=n_components, params='t', init_params='st', **kwargs)
        self.emission_model = emission_model

    def _check_and_set_n_features(self, X):
        """Record the feature count of ``X`` and verify it is consistent.

        A 2-D ``X`` contributes ``X.shape[1]`` features; a 3-D ``X``
        contributes ``X.shape[1] * X.shape[2]``.

        Raises:
            ValueError: if ``X`` is neither 2-D nor 3-D, or if the feature
                count differs from a previously stored ``self.n_features``.
        """
        if X.ndim == 2:
            n_features = X.shape[1]
        elif X.ndim == 3:
            n_features = X.shape[1] * X.shape[2]
        else:
            raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')

        if hasattr(self, "n_features"):
            if self.n_features != n_features:
                raise ValueError(
                    f"Unexpected number of dimensions, got {n_features} but "
                    f"expected {self.n_features}")
        else:
            self.n_features = n_features

    def _get_n_fit_scalars_per_param(self):
        """Return the number of scalars behind each parameter code.

        ``'s'`` (start probabilities) has ``n_components`` entries and
        ``'t'`` (transition matrix) has ``n_components ** 2``.
        """
        nc = self.n_components
        return {
            "s": nc,
            "t": nc ** 2}

    def _compute_likelihood(self, X):
        """Return ``emission_model.predict_proba(X)`` as the per-state likelihood.

        NOTE(review): the classifier output is used directly, without any
        rescaling by class priors.
        """
        p = self.emission_model.predict_proba(X)
        return p


def extract_baseline_feature(model, raw, step):
    """Cut filter-bank features of a continuous recording into decimated epochs.

    Args:
        model: two-element sequence; the first element is the feature
            extractor (must provide ``transform``), the second is ignored here.
        raw: recording object providing ``info['sfreq']``, ``get_data()`` and
            ``times`` (presumably an MNE ``Raw`` -- confirm against
            ``neo.raw_loader``).
        step: epoch length / stride in seconds.

    Returns:
        The epochs returned by ``bci_utils.cut_epochs`` after two decimation
        passes along the last axis.
    """
    fs = raw.info['sfreq']
    feat_extractor, _ = model
    filter_bank_data = feat_extractor.transform(raw.get_data())
    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
    filter_bank_epoch = bci_utils.cut_epochs((0, step, fs), filter_bank_data, timestamps)

    # Decimate in two equal passes of floor(sqrt(fs / 10)) each, which brings
    # the rate down to 10 Hz when fs / 10 is a perfect square.
    decimate_rate = int(np.sqrt(fs / 10))
    for _ in range(2):
        filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)

    return filter_bank_epoch


def extract_riemann_feature(model, raw, step):
    """Cut a continuous recording into epochs and transform them with ``cov_model``.

    Args:
        model: four-element sequence ``(feat_extractor, scaler, cov_model, _)``;
            the last element is ignored here.
        raw: recording object providing ``info['sfreq']``, ``get_data()`` and
            ``times``.
        step: epoch length / stride in seconds.

    Returns:
        Output of ``cov_model.transform`` applied to the scaled epochs.
    """
    fs = raw.info['sfreq']
    feat_extractor, scaler, cov_model, _ = model
    filtered_data = feat_extractor.transform(raw.get_data())
    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
    X = bci_utils.cut_epochs((0, step, fs), filtered_data, timestamps)
    X = scaler.transform(X)
    X_cov = cov_model.transform(X)
    return X_cov


def _split_continuous(time_range, step, fs):
    """Return epoch start positions, in samples, covering ``time_range``.

    Positions run from ``int(time_range[0] * fs)`` up to (but excluding)
    ``int(time_range[-1] * fs)`` with a stride of ``int(step * fs)`` samples.
    """
    return np.arange(int(time_range[0] * fs),
                     int(time_range[-1] * fs),
                     int(step * fs), dtype=np.int64)


def parse_args():
    """Parse the command-line arguments of the training script.

    ``--subj`` and ``--model-filename`` are mandatory because both are used
    to build the data and model paths.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    parser.add_argument(
        '--state-change-threshold',
        '-scth',
        dest='state_change_threshold',
        help='Threshold for HMM state change',
        default=0.75,
        type=float
    )
    parser.add_argument(
        '--state-trans-prob',
        '-stp',
        dest='state_trans_prob',
        help='Transition probability for HMM state change',
        default=0.8,
        type=float
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        required=True,
        type=str
    )

    return parser.parse_args()


if __name__ == '__main__':
    # Guarded so that importing this module (e.g. to reuse HMMClassifier)
    # does not parse sys.argv or start a training run.
    args = parse_args()

    # load model and fit hmm
    subj_name = args.subj
    model_filename = args.model_filename

    data_dir = f'./data/{subj_name}/'
    model_path = f'./static/models/{subj_name}/{model_filename}'

    # load model
    model_type, _ = bci_utils.parse_model_type(model_filename)
    model = joblib.load(model_path)

    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['hmm_sessions']

    raw, event_id = neo.raw_loader(data_dir, sessions, True)

    # cut into buffer len epochs
    if model_type == 'baseline':
        feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
    elif model_type == 'riemann':
        feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
    else:
        raise ValueError(f'Unsupported model type: {model_type!r}')

    # initiate hmm model; the last element of the loaded model is the classifier
    hmm_model = HMMClassifier(model[-1], n_iter=100)
    hmm_model.fit(feature)

    # decode
    log_probs, state_seqs = hmm_model.decode(feature)

    plt.figure()
    plt.plot(state_seqs)

    # save transmat
    # NOTE(review): split(".")[0] drops everything after the FIRST dot of the
    # filename; kept as-is in case other code rebuilds the same stem.
    np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)

    plt.show()