123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- '''
- Use a trained classifier as the emission model and train the HMM transition matrix on free-grasping tasks.
- '''
- import os
- import argparse
- from hmmlearn import hmm
- import numpy as np
- import yaml
- import joblib
- from scipy import signal
- import matplotlib.pyplot as plt
- from dataloaders import neo
- import bci_core.utils as bci_utils
- from settings.config import settings
# Project configuration, indexed by key; this script reads its 'reref' and
# 'buffer_length' entries when loading and windowing the recordings.
config_info = settings.CONFIG_INFO
class HMMClassifier(hmm.BaseHMM):
    """HMM whose emission probabilities are supplied by a pre-trained classifier.

    One hidden state is created per entry of ``emission_model.classes_``.
    Only the transition matrix is re-estimated during ``fit`` (``params='t'``);
    both the start probabilities and the transition matrix are initialised by
    hmmlearn before training (``init_params='st'``).

    Args:
        emission_model: fitted classifier exposing ``classes_`` and
            ``predict_proba``.
        **kwargs: forwarded unchanged to ``hmm.BaseHMM`` (e.g. ``n_iter``).
    """

    def __init__(self, emission_model, **kwargs):
        n_components = len(emission_model.classes_)
        super().__init__(n_components=n_components, params='t', init_params='st', **kwargs)
        self.emission_model = emission_model

    def _check_and_set_n_features(self, X):
        """Record the feature dimensionality of ``X``, or check it against a previous call.

        2-D input contributes ``X.shape[1]`` features; 3-D input is counted as
        the flattened ``X.shape[1] * X.shape[2]``.

        Raises:
            ValueError: if ``X`` is not 2-D or 3-D, or if its feature count
                differs from the one recorded earlier.
        """
        # NOTE(review): hmmlearn's fit validates X with sklearn's check_array,
        # which may reject 3-D arrays before this runs -- confirm the 3-D branch
        # is reachable.
        if X.ndim == 2:
            n_features = X.shape[1]
        elif X.ndim == 3:
            n_features = X.shape[1] * X.shape[2]
        else:
            raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')
        if hasattr(self, "n_features"):
            if self.n_features != n_features:
                raise ValueError(
                    f"Unexpected number of dimensions, got {n_features} but "
                    f"expected {self.n_features}")
        else:
            self.n_features = n_features

    def _get_n_fit_scalars_per_param(self):
        """Return the number of free scalars per trainable parameter group.

        ``startprob_`` sums to 1 and every ``transmat_`` row sums to 1, so each
        loses one degree of freedom. This matches hmmlearn's built-in models,
        and the counts feed ``aic``/``bic``.
        """
        nc = self.n_components
        return {
            "s": nc - 1,
            "t": nc * (nc - 1),
        }

    def _compute_likelihood(self, X):
        """Return per-sample, per-state emission scores from the classifier.

        Column ``i`` corresponds to ``emission_model.classes_[i]``, following the
        scikit-learn ``predict_proba`` convention.
        """
        # NOTE(review): predict_proba yields posteriors p(state | x), not
        # likelihoods p(x | state). The per-sample factor p(x) does not change
        # decoding, but the class prior p(state) does. Dividing by class priors
        # may be needed if the classifier was trained on imbalanced classes --
        # TODO confirm.
        return self.emission_model.predict_proba(X)
def extract_embedded_feature(model, raw, step=0.1, buffer_length=0.5):
    """Slide a fixed-length window over ``raw`` and embed every window.

    Args:
        model: three-step pipeline ``(feature extractor, embedder, classifier)``.
            Only the first two steps are used here.
        raw: recording exposing ``info['sfreq']``, ``times`` and ``get_data()``
            (presumably an MNE ``Raw`` object -- confirm).
        step: hop between consecutive window starts, in seconds.
        buffer_length: window length, in seconds.

    Returns:
        Whatever ``embedder.transform`` produces for the stacked windows.
    """
    sampling_rate = raw.info['sfreq']
    extractor, embedder, _classifier = model

    continuous = extractor.transform(raw.get_data())

    first_t, last_t = raw.times[0], raw.times[-1]
    window_starts = _split_continuous((first_t, last_t), step, sampling_rate, buffer_length)

    windows = bci_utils.cut_epochs((0, buffer_length, sampling_rate), continuous, window_starts)
    return embedder.transform(windows)
- def _split_continuous(time_range, step, fs, window_size):
- return np.arange(int(time_range[0] * fs),
- int(time_range[-1] * fs) - int(window_size * fs),
- int(step * fs), dtype=np.int64)
def parse_args():
    """Parse this script's command-line options.

    Returns:
        ``argparse.Namespace`` with ``subj`` and ``model_filename`` attributes.
        Each is a string, or ``None`` when its flag is omitted.
    """
    cli = argparse.ArgumentParser(description='Model validation')
    option_specs = (
        ('--subj', 'subj', 'Subject name'),
        ('--model-filename', 'model_filename', 'Model filename'),
    )
    for flag, destination, help_text in option_specs:
        cli.add_argument(flag, dest=destination, help=help_text, default=None, type=str)
    return cli.parse_args()
def main():
    """Fit the HMM transition matrix for one subject and save it beside the model.

    Steps:
    1. Load the trained pipeline.
    2. Load the sessions listed under ``hmm_sessions`` in the subject's
       ``train_info.yml``.
    3. Embed sliding windows of the recording.
    4. Fit the transition matrix and plot the decoded state sequence.
    5. Write ``transmat_`` to ``<model path without extension>_transmat.txt``.
    """
    args = parse_args()
    subj_name = args.subj
    model_filename = args.model_filename
    if subj_name is None or model_filename is None:
        # Fail with a readable message instead of a TypeError from os.path.join.
        raise SystemExit('Both --subj and --model-filename are required.')

    data_dir = os.path.join(settings.DATA_PATH, subj_name)
    model_path = os.path.join(settings.MODEL_PATH, subj_name, model_filename)

    # Load the trained pipeline and the sessions reserved for HMM fitting.
    model = joblib.load(model_path)
    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['hmm_sessions']
    raw, _event_id = neo.raw_loader(data_dir, sessions, config_info['reref'])

    # Cut the continuous recording into buffer-length windows and embed them.
    feature = extract_embedded_feature(model, raw, step=0.1, buffer_length=config_info['buffer_length'])

    # The pipeline's last step (the classifier) serves as the emission model.
    # TODO: build an informed transmat initialisation.
    hmm_model = HMMClassifier(model[-1], n_iter=100)
    hmm_model.fit(feature)

    # decode returns (log probability of the sequence, most likely state sequence).
    _log_prob, state_seq = hmm_model.decode(feature)
    plt.figure()
    plt.plot(state_seq)

    # Save next to the model file, under settings.MODEL_PATH, not relative to the
    # CWD. splitext keeps dotted stems such as 'model.v2.pkl' -> 'model.v2'.
    transmat_path = os.path.splitext(model_path)[0] + '_transmat.txt'
    np.savetxt(transmat_path, hmm_model.transmat_)
    plt.show()


if __name__ == '__main__':
    main()
|