'''
Use the trained classifier as the emission model; train the HMM transition matrix on free grasping tasks.
'''
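# Example invocation (subject name and model filename are placeholders; the actual
# script name depends on where this file lives in the repo):
#   python hmm_train.py --subj subj01 --model-filename riemann_model.pkl -scth 0.75 -stp 0.8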
import os
import argparse

from hmmlearn import hmm
import numpy as np
import yaml
import joblib
from scipy import signal
import matplotlib.pyplot as plt

from dataloaders import neo
import bci_core.utils as bci_utils
from settings.config import settings

config_info = settings.CONFIG_INFO


class HMMClassifier(hmm.BaseHMM):
    # TODO: find a way around sklearn.utils.validation.check_array used inside hmmlearn;
    #  for now the hmmlearn source was patched directly (all check_array calls removed).
    # TODO: a viable alternative is to reorganize the model so that feature extraction is
    #  separated from the final classifier; the HMM would then wrap only the final
    #  classifier and only needs to accept 2-D features.
    def __init__(self, emission_model, **kwargs):
        n_components = len(emission_model.classes_)
        super(HMMClassifier, self).__init__(n_components=n_components, params='t', init_params='st', **kwargs)
        self.emission_model = emission_model

    def _check_and_set_n_features(self, X):
        if X.ndim == 2:  # (n_samples, n_features)
            n_features = X.shape[1]
        elif X.ndim == 3:  # (n_samples, d1, d2), counted as flattened features
            n_features = X.shape[1] * X.shape[2]
        else:
            raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')
        if hasattr(self, "n_features"):
            if self.n_features != n_features:
                raise ValueError(
                    f"Unexpected number of features, got {n_features} but "
                    f"expected {self.n_features}")
        else:
            self.n_features = n_features

    def _get_n_fit_scalars_per_param(self):
        nc = self.n_components
        return {
            "s": nc,
            "t": nc ** 2,
        }

    def _compute_likelihood(self, X):
        # Per-state likelihoods come from the pretrained emission model; predict_proba
        # columns follow emission_model.classes_, which fixes the HMM state order.
        p = self.emission_model.predict_proba(X)
        return p

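# Sketch of the refactor suggested in the TODO above (names are hypothetical): if the
# saved pipeline were split into a feature pipeline and a final classifier `clf` that
# accepts 2-D features, the HMM could wrap `clf` directly:
#   features_2d = feat_pipeline.transform(raw.get_data())  # shape (n_windows, n_features)
#   hmm_model = HMMClassifier(clf, n_iter=100)
#   hmm_model.fit(features_2d)
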

def extract_baseline_feature(model, raw, step=0.1, buffer_length=0.5):
    fs = raw.info['sfreq']
    feat_extractor, _ = model
    filter_bank_data = feat_extractor.transform(raw.get_data())
    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs, buffer_length)
    filter_bank_epoch = bci_utils.cut_epochs((0, buffer_length, fs), filter_bank_data, timestamps)
    # two-stage decimation: each stage decimates by sqrt(fs / 10), bringing the
    # effective sampling rate down to 10 Hz
    decimate_rate = np.sqrt(fs / 10).astype(np.int16)
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    return filter_bank_epoch


def extract_riemann_feature(model, raw, step=0.1, buffer_length=0.5):
    fs = raw.info['sfreq']
    feat_extractor, scaler, cov_model, _ = model
    filtered_data = feat_extractor.transform(raw.get_data())
    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs, buffer_length)
    X = bci_utils.cut_epochs((0, buffer_length, fs), filtered_data, timestamps)
    X = scaler.transform(X)
    X_cov = cov_model.transform(X)
    return X_cov

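# If cov_model produces one covariance matrix per window (an assumption about the saved
# pipeline), X_cov is 3-D with shape (n_windows, n_channels, n_channels);
# HMMClassifier._check_and_set_n_features flattens the last two dimensions when
# counting features.
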

def _split_continuous(time_range, step, fs, window_size):
    return np.arange(int(time_range[0] * fs),
                     int(time_range[-1] * fs) - int(window_size * fs),
                     int(step * fs), dtype=np.int64)

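# Worked example of the window arithmetic (illustrative numbers): with
# time_range=(0, 10), fs=1000, step=0.1 and window_size=0.5, this returns
# np.arange(0, 9500, 100), i.e. 95 window start samples: 0, 100, ..., 9400.
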

def parse_args():
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str
    )
    parser.add_argument(
        '--state-change-threshold',
        '-scth',
        dest='state_change_threshold',
        help='Threshold for HMM state change',
        default=0.75,
        type=float
    )
    parser.add_argument(
        '--state-trans-prob',
        '-stp',
        dest='state_trans_prob',
        help='Transition probability for HMM state change',
        default=0.8,
        type=float
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        default=None,
        type=str
    )
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    # load model and fit hmm
    subj_name = args.subj
    model_filename = args.model_filename
    data_dir = f'./data/{subj_name}/'

    model_path = f'./static/models/{subj_name}/{model_filename}'
    # load pretrained model
    model_type, _ = bci_utils.parse_model_type(model_filename)
    model = joblib.load(model_path)

    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['hmm_sessions']
    raw, event_id = neo.raw_loader(data_dir, sessions, True)

    # cut into buffer-length epochs and extract features with the pretrained pipeline
    if model_type == 'baseline':
        feature = extract_baseline_feature(model, raw, step=0.1, buffer_length=config_info['buffer_length'])
    elif model_type == 'riemann':
        feature = extract_riemann_feature(model, raw, step=0.1, buffer_length=config_info['buffer_length'])
    else:
        raise ValueError(f'Unsupported model type: {model_type}')

    # initiate hmm model with the pretrained classifier as emission model
    hmm_model = HMMClassifier(model[-1], n_iter=100)
    hmm_model.fit(feature)

    # decode and plot the inferred state sequence
    log_probs, state_seqs = hmm_model.decode(feature)
    plt.figure()
    plt.plot(state_seqs)

    # save the learned transition matrix
    np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
    plt.show()