"""Use a trained classifier as the HMM emission model and fit the HMM
transition matrix on free-grasping sessions.

Run as a script; the fitted transition matrix is saved next to the model as
``./static/models/<subj>/<model stem>_transmat.txt``.
"""
import os
import argparse

from hmmlearn import hmm
import numpy as np
import yaml
import joblib
from scipy import signal
import matplotlib.pyplot as plt

from dataloaders import neo
import bci_core.utils as bci_utils
from settings.config import settings


config_info = settings.CONFIG_INFO


class HMMClassifier(hmm.BaseHMM):
    """HMM whose emission probabilities come from a pre-trained classifier.

    Only the transition matrix is re-estimated by ``fit`` (``params='t'``);
    start probabilities and transitions are initialised by hmmlearn
    (``init_params='st'``).

    Parameters
    ----------
    emission_model : classifier exposing ``classes_`` and ``predict_proba``
        One hidden state is created per entry of ``emission_model.classes_``.
    **kwargs
        Forwarded to ``hmm.BaseHMM`` (e.g. ``n_iter``).
    """
    # TODO: how to bypass sklearn.check_array, currently I modified the src of hmmlearn (remove all the check_array)
    def __init__(self, emission_model, **kwargs):
        n_components = len(emission_model.classes_)
        super().__init__(n_components=n_components, params='t', init_params='st', **kwargs)
        self.emission_model = emission_model

    def _check_and_set_n_features(self, X):
        """Record (or validate against a previous call) the feature count of ``X``.

        2-D input contributes ``X.shape[1]`` features; 3-D input is counted as
        ``X.shape[1] * X.shape[2]`` features.

        Raises
        ------
        ValueError
            If ``X`` is neither 2-D nor 3-D, or its feature count differs from
            the one recorded earlier.
        """
        if X.ndim == 2:
            n_features = X.shape[1]
        elif X.ndim == 3:
            n_features = X.shape[1] * X.shape[2]
        else:
            raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')
        if hasattr(self, "n_features"):
            if self.n_features != n_features:
                raise ValueError(
                    f"Unexpected number of dimensions, got {n_features} but "
                    f"expected {self.n_features}")
        else:
            self.n_features = n_features

    def _get_n_fit_scalars_per_param(self):
        """Return the scalar count per fitted parameter ('s' start probs, 't' transitions)."""
        # NOTE(review): these are raw entry counts (nc, nc**2), not free-parameter
        # counts (nc - 1, nc * (nc - 1)) — confirm this is intended.
        nc = self.n_components
        return {
            "s": nc,
            "t": nc ** 2}

    def _compute_likelihood(self, X):
        """Return per-sample emission probabilities, one column per hidden state.

        Columns follow ``emission_model.classes_``.
        """
        # NOTE(review): predict_proba gives class posteriors, which are used directly
        # in place of emission likelihoods (no division by class priors) — confirm.
        return self.emission_model.predict_proba(X)


def extract_baseline_feature(model, raw, step):
    """Cut ``raw`` into consecutive ``step``-second epochs of filter-bank features.

    Parameters
    ----------
    model : sequence
        Baseline pipeline unpacked as ``(feat_extractor, classifier)``; only
        ``feat_extractor.transform`` is used here.
    raw : Raw-like object (presumably MNE ``Raw`` — confirm)
        Must expose ``info['sfreq']``, ``times`` and ``get_data()``.
    step : float
        Epoch length in seconds.

    Returns
    -------
    numpy.ndarray
        Epoched features, decimated twice along the last axis.
    """
    fs = raw.info['sfreq']
    feat_extractor, _ = model
    filter_bank_data = feat_extractor.transform(raw.get_data())
    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
    filter_bank_epoch = bci_utils.cut_epochs((0, step, fs), filter_bank_data, timestamps)
    # Two decimation passes by int(sqrt(fs / 10)). This yields exactly 10 Hz only when
    # fs / 10 is a perfect square (e.g. 250 or 1000 Hz); otherwise the truncated factor
    # lands near 10 Hz. NOTE(review): left unchanged because it presumably has to match
    # the feature extraction the classifier was trained with — confirm.
    decimate_rate = np.sqrt(fs / 10).astype(np.int16)
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    return filter_bank_epoch


def extract_riemann_feature(model, raw, step):
    """Cut ``raw`` into ``step``-second epochs and transform them with ``cov_model``.

    ``model`` is unpacked as ``(feat_extractor, scaler, cov_model, classifier)``;
    the classifier is not used here. Epochs are filtered, scaled, then passed
    through ``cov_model.transform`` (presumably covariance estimation — confirm).
    """
    fs = raw.info['sfreq']
    feat_extractor, scaler, cov_model, _ = model
    filtered_data = feat_extractor.transform(raw.get_data())
    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
    X = bci_utils.cut_epochs((0, step, fs), filtered_data, timestamps)
    X = scaler.transform(X)
    X_cov = cov_model.transform(X)
    return X_cov


def _split_continuous(time_range, step, fs):
    """Return start sample indices of consecutive ``step``-second windows.

    Starts run from ``int(time_range[0] * fs)`` up to (excluding)
    ``int(time_range[-1] * fs)`` with a stride of ``int(step * fs)`` samples.
    """
    # NOTE(review): the last start may lie less than one window before the end of the
    # recording; how cut_epochs handles that tail is not visible here — confirm.
    return np.arange(int(time_range[0] * fs), int(time_range[-1] * fs), int(step * fs), dtype=np.int64)


def parse_args():
    """Parse command-line arguments for HMM training."""
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        required=True,
        type=str
    )
    # NOTE: the two HMM options below are parsed but not used by this script.
    parser.add_argument(
        '--state-change-threshold',
        '-scth',
        dest='state_change_threshold',
        help='Threshold for HMM state change',
        default=0.75,
        type=float
    )
    parser.add_argument(
        '--state-trans-prob',
        '-stp',
        dest='state_trans_prob',
        help='Transition probability for HMM state change',
        default=0.8,
        type=float
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        default=None,
        required=True,
        type=str
    )
    return parser.parse_args()


def main():
    """Fit the HMM transition matrix for one subject/model and save it.

    Reads ``hmm_sessions`` from ``./data/<subj>/train_info.yml``, extracts
    buffer-length features with the pipeline in
    ``./static/models/<subj>/<model_filename>``, fits the transitions, plots
    the decoded state sequence and writes ``<model stem>_transmat.txt``.

    Raises
    ------
    ValueError
        If the model type parsed from the filename is not 'baseline' or 'riemann'.
    """
    args = parse_args()

    subj_name = args.subj
    model_filename = args.model_filename
    data_dir = f'./data/{subj_name}/'
    model_path = f'./static/models/{subj_name}/{model_filename}'

    # load model (joblib unpickles: only load trusted, locally produced models)
    model_type, _ = bci_utils.parse_model_type(model_filename)
    model = joblib.load(model_path)

    with open(os.path.join(data_dir, 'train_info.yml'), 'r', encoding='utf-8') as f:
        info = yaml.safe_load(f)
    sessions = info['hmm_sessions']
    raw = neo.raw_loader(data_dir, sessions, True)

    # cut into buffer-length epochs
    step = config_info['buffer_length']
    if model_type == 'baseline':
        feature = extract_baseline_feature(model, raw, step)
    elif model_type == 'riemann':
        feature = extract_riemann_feature(model, raw, step)
    else:
        raise ValueError(f'Unsupported model type: {model_type!r}')

    # the last pipeline element is the classifier used as the emission model
    hmm_model = HMMClassifier(model[-1], n_iter=100)
    hmm_model.fit(feature)

    # decode (log-probability of the path is not needed)
    _, state_seqs = hmm_model.decode(feature)
    plt.figure()
    plt.plot(state_seqs)

    # save transmat
    # NOTE(review): split(".")[0] truncates stems that contain dots; kept unchanged so
    # the output name stays the same as before — confirm no loader depends otherwise.
    np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
    plt.show()


if __name__ == '__main__':
    main()