"""Use a trained classifier as the emission model and train the HMM
transition matrix on free grasping tasks.
"""
import os
import argparse

from hmmlearn import hmm
import numpy as np
import yaml
import joblib
from scipy import signal
import matplotlib.pyplot as plt

from dataloaders import neo
import bci_core.utils as bci_utils
from settings.config import settings


config_info = settings.CONFIG_INFO


class HMMClassifier(hmm.BaseHMM):
    """HMM whose emission probabilities are produced by a fitted classifier.

    The number of hidden states equals the number of classes exposed by
    ``emission_model.classes_``. Only the transition matrix is listed in
    ``params`` (i.e. is re-estimated by ``fit``); the start probabilities
    and the transition matrix are both listed in ``init_params``.

    Parameters
    ----------
    emission_model
        Fitted classifier providing ``classes_`` and ``predict_proba``.
    **kwargs
        Forwarded unchanged to ``hmm.BaseHMM.__init__`` (e.g. ``n_iter``).
    """

    def __init__(self, emission_model, **kwargs):
        n_components = len(emission_model.classes_)
        # 's' = start probabilities, 't' = transition matrix (hmmlearn codes).
        super().__init__(n_components=n_components, params='t',
                         init_params='st', **kwargs)
        self.emission_model = emission_model

    def _check_and_set_n_features(self, X):
        """Record the per-sample feature count of ``X`` and verify that it
        matches any previously recorded value.

        Raises
        ------
        ValueError
            If ``X`` is neither 2-D nor 3-D, or if its feature count differs
            from the ``n_features`` already stored on the instance.
        """
        if X.ndim == 2:
            # One flat feature vector per sample.
            n_features = X.shape[1]
        elif X.ndim == 3:
            # Two trailing axes per sample: count them as one flattened vector.
            n_features = X.shape[1] * X.shape[2]
        else:
            raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')

        if hasattr(self, "n_features"):
            if self.n_features != n_features:
                raise ValueError(
                    f"Unexpected number of dimensions, got {n_features} but "
                    f"expected {self.n_features}")
        else:
            self.n_features = n_features

    def _get_n_fit_scalars_per_param(self):
        """Return the scalar count of each parameter code ('s' and 't')."""
        nc = self.n_components
        return {
            "s": nc,
            "t": nc ** 2}

    def _compute_likelihood(self, X):
        """Return the emission model's ``predict_proba`` output for ``X``.

        NOTE(review): presumably shaped (n_samples, n_components) with columns
        ordered like ``emission_model.classes_`` - confirm against the
        classifier actually used. These are class posteriors used directly as
        emission likelihoods; confirm that this is the intended approximation.
        """
        p = self.emission_model.predict_proba(X)
        return p


def extract_embedded_feature(model, raw, step=0.1, buffer_length=0.5):
    """Cut a continuous recording into sliding windows and embed each one.

    Parameters
    ----------
    model
        Three-element sequence unpacked as (feature extractor, embedder,
        third element); the third element is not used here.
    raw
        Recording exposing ``info['sfreq']``, ``get_data()`` and ``times``.
    step
        Offset between consecutive window starts, in the units of
        ``raw.times`` (presumably seconds - TODO confirm).
    buffer_length
        Window length, in the same units as ``step``.

    Returns
    -------
    The output of ``embedder.transform`` applied to the cut windows.
    """
    fs = raw.info['sfreq']

    feat_extractor, embedder, _ = model

    filtered_data = feat_extractor.transform(raw.get_data())
    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs,
                                   buffer_length)
    # Each window spans [0, buffer_length) relative to its start sample.
    X = bci_utils.cut_epochs((0, buffer_length, fs), filtered_data, timestamps)
    X_embed = embedder.transform(X)
    return X_embed


def _split_continuous(time_range, step, fs, window_size):
    """Return the start sample index of every sliding window.

    Starts run from ``int(time_range[0] * fs)`` in increments of
    ``int(step * fs)`` and stop before
    ``int(time_range[-1] * fs) - int(window_size * fs)``, so no window
    reaches past the end of the range.
    """
    return np.arange(int(time_range[0] * fs),
                     int(time_range[-1] * fs) - int(window_size * fs),
                     int(step * fs), dtype=np.int64)


def parse_args():
    """Parse the command-line arguments of this script.

    Both options are required: they are used to build the data and model
    paths, so a missing value is reported by argparse up front instead of
    surfacing later as a ``TypeError`` inside ``os.path.join``.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        required=True,
        type=str
    )
    return parser.parse_args()


def main():
    """Fit the HMM transition matrix for one subject, plot the decoded state
    sequence and save the transition matrix as text.
    """
    args = parse_args()

    # load model and fit hmm
    subj_name = args.subj
    model_filename = args.model_filename

    data_dir = os.path.join(settings.DATA_PATH, subj_name)
    model_path = os.path.join(settings.MODEL_PATH, subj_name, model_filename)

    # load model
    model = joblib.load(model_path)

    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['hmm_sessions']

    # The second return value of raw_loader is not needed here.
    raw, _ = neo.raw_loader(data_dir, sessions, config_info['reref'])

    # cut into buffer len epochs
    feature = extract_embedded_feature(model, raw, step=0.1,
                                       buffer_length=config_info['buffer_length'])

    # initiate hmm model
    # TODO: building transmat init
    hmm_model = HMMClassifier(model[-1], n_iter=100)
    hmm_model.fit(feature)

    # decode
    log_probs, state_seqs = hmm_model.decode(feature)
    plt.figure()
    plt.plot(state_seqs)

    # save transmat
    transmat_path = f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt'
    # Create the target directory first so the fitted result is not lost to a
    # FileNotFoundError when the subject folder does not exist yet.
    os.makedirs(os.path.dirname(transmat_path), exist_ok=True)
    np.savetxt(transmat_path, hmm_model.transmat_)

    plt.show()


if __name__ == '__main__':
    main()