|
@@ -0,0 +1,161 @@
|
|
|
'''
Use a trained classifier as the emission model and train the HMM transition matrix on free grasping tasks.
'''
|
|
|
+import os
|
|
|
+import argparse
|
|
|
+
|
|
|
+from hmmlearn import hmm
|
|
|
+import numpy as np
|
|
|
+import yaml
|
|
|
+import joblib
|
|
|
+from scipy import signal
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+
|
|
|
+from dataloaders import neo
|
|
|
+import bci_core.utils as bci_utils
|
|
|
+from settings.config import settings
|
|
|
+
|
|
|
+
|
|
|
+config_info = settings.CONFIG_INFO
|
|
|
+
|
|
|
class HMMClassifier(hmm.BaseHMM):
    """HMM whose emission likelihoods are produced by a fitted classifier.

    The number of hidden states equals ``len(emission_model.classes_)`` and
    the per-state likelihood of an observation is whatever
    ``emission_model.predict_proba`` returns for it. The base class is
    configured with ``params='t'`` and ``init_params='st'``.
    """
    # TODO: how to bypass sklearn.check_array, currently I modified the src of hmmlearn (remove all the check_array)

    def __init__(self, emission_model, **kwargs):
        """Create the HMM around an already-fitted classifier.

        Args:
            emission_model: Object exposing ``classes_`` and
                ``predict_proba``.
            **kwargs: Forwarded unchanged to ``hmm.BaseHMM.__init__``
                (e.g. ``n_iter``).
        """
        super().__init__(
            n_components=len(emission_model.classes_),
            params='t',
            init_params='st',
            **kwargs
        )

        self.emission_model = emission_model

    def _check_and_set_n_features(self, X):
        """Record the feature count of ``X`` or verify it against the stored one.

        Accepts 2-D input (feature count is ``X.shape[1]``) and 3-D input
        (feature count is ``X.shape[1] * X.shape[2]``).

        Raises:
            ValueError: If ``X`` is neither 2-D nor 3-D, or if its feature
                count differs from a previously stored ``n_features``.
        """
        if X.ndim not in (2, 3):
            raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')

        n_features = X.shape[1] if X.ndim == 2 else X.shape[1] * X.shape[2]

        # First call: nothing to compare against yet, just remember the size.
        if not hasattr(self, "n_features"):
            self.n_features = n_features
            return

        if self.n_features != n_features:
            raise ValueError(
                f"Unexpected number of dimensions, got {n_features} but expected {self.n_features}")

    def _get_n_fit_scalars_per_param(self):
        """Return the scalar count reported for each parameter key ('s', 't')."""
        n_states = self.n_components
        return {"s": n_states, "t": n_states ** 2}

    def _compute_likelihood(self, X):
        """Return ``emission_model.predict_proba(X)`` as the per-state likelihood."""
        return self.emission_model.predict_proba(X)
|
|
|
+
|
|
|
+
|
|
|
def extract_baseline_feature(model, raw, step):
    """Build decimated filter-bank epochs from a continuous recording.

    Args:
        model: Two-item sequence; the first item is the feature extractor
            (must provide ``transform``), the second is ignored here.
        raw: Recording object providing ``info['sfreq']``, ``get_data()``
            and ``times``.
        step: Epoch length and stride in seconds.

    Returns:
        The epochs returned by ``bci_utils.cut_epochs`` after two
        decimation passes along the last axis.
    """
    sfreq = raw.info['sfreq']
    extractor, _unused_classifier = model

    banked = extractor.transform(raw.get_data())
    epoch_starts = _split_continuous((raw.times[0], raw.times[-1]), step, sfreq)
    epochs = bci_utils.cut_epochs((0, step, sfreq), banked, epoch_starts)

    # Two equal stages, each by the truncated square root of fs / 10, so the
    # overall reduction targets 10 Hz (exact when fs / 10 is a perfect square).
    stage_factor = np.sqrt(sfreq / 10).astype(np.int16)
    for _ in range(2):
        epochs = signal.decimate(epochs, stage_factor, axis=-1, zero_phase=True)
    return epochs
|
|
|
+
|
|
|
+
|
|
|
def extract_riemann_feature(model, raw, step):
    """Build covariance features from a continuous recording.

    Args:
        model: Four-item sequence ``(feature extractor, scaler, covariance
            model, <ignored>)``; the first three must provide ``transform``.
        raw: Recording object providing ``info['sfreq']``, ``get_data()``
            and ``times``.
        step: Epoch length and stride in seconds.

    Returns:
        Output of ``cov_model.transform`` applied to the scaled epochs.
    """
    sfreq = raw.info['sfreq']
    extractor, scaler, cov_model, _unused_classifier = model

    filtered = extractor.transform(raw.get_data())
    epoch_starts = _split_continuous((raw.times[0], raw.times[-1]), step, sfreq)
    epochs = bci_utils.cut_epochs((0, step, sfreq), filtered, epoch_starts)

    # Scale the epochs first, then map them to covariance features.
    return cov_model.transform(scaler.transform(epochs))
|
|
|
+
|
|
|
+
|
|
|
+def _split_continuous(time_range, step, fs):
|
|
|
+ return np.arange(int(time_range[0] * fs),
|
|
|
+ int(time_range[-1] * fs),
|
|
|
+ int(step * fs), dtype=np.int64)
|
|
|
+
|
|
|
+
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the HMM training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``subj``, ``model_filename``,
        ``state_change_threshold`` and ``state_trans_prob``.
    """
    parser = argparse.ArgumentParser(
        description='Train the HMM transition matrix on top of a trained classifier'
    )
    # Both values are interpolated into file paths by the script, so a
    # missing one must be rejected here rather than become the text "None".
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    parser.add_argument(
        '--state-change-threshold',
        '-scth',
        dest='state_change_threshold',
        help='Threshold for HMM state change',
        default=0.75,
        type=float
    )
    parser.add_argument(
        '--state-trans-prob',
        '-stp',
        dest='state_trans_prob',
        help='Transition probability for HMM state change',
        default=0.8,
        type=float
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        required=True,
        type=str
    )
    return parser.parse_args(argv)
|
|
|
+
|
|
|
# NOTE(review): this section runs at module import time; consider moving it
# behind an ``if __name__ == '__main__':`` guard if the class or helpers above
# ever need to be imported from elsewhere.
args = parse_args()
# load model and fit hmm
# NOTE(review): args.state_change_threshold and args.state_trans_prob are
# parsed but not used anywhere in the code visible here.
subj_name = args.subj
model_filename = args.model_filename

data_dir = f'./data/{subj_name}/'

model_path = f'./static/models/{subj_name}/{model_filename}'

# Resolve the feature pipeline before touching any file, so an unsupported
# model type fails immediately and with a message naming the offending value.
model_type, _ = bci_utils.parse_model_type(model_filename)
feature_extractors = {
    'baseline': extract_baseline_feature,
    'riemann': extract_riemann_feature,
}
if model_type not in feature_extractors:
    raise ValueError(
        f"Unsupported model type {model_type!r} parsed from "
        f"{model_filename!r}; expected 'baseline' or 'riemann'")

# load model
model = joblib.load(model_path)

with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
    info = yaml.safe_load(f)
sessions = info['hmm_sessions']

raw = neo.raw_loader(data_dir, sessions, True)

# cut into buffer len epochs
feature = feature_extractors[model_type](model, raw, config_info['buffer_length'])

# initiate hmm model; the last item of the loaded model is used as the
# emission classifier
hmm_model = HMMClassifier(model[-1], n_iter=100)
hmm_model.fit(feature)

# decode: the first return value is named log_probs, the second is the
# state sequence that gets plotted
log_probs, state_seqs = hmm_model.decode(feature)
plt.figure()
plt.plot(state_seqs)

# save transmat
np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)

plt.show()
|