'''
Use a trained classifier as the HMM emission model and fit the HMM state-transition matrix on free-grasping sessions.
'''
import os
import argparse

from hmmlearn import hmm
import numpy as np
import yaml
import joblib
from scipy import signal
import matplotlib.pyplot as plt

from dataloaders import neo
import bci_core.utils as bci_utils
from settings.config import settings


# Project-wide configuration mapping; in this file only config_info['buffer_length']
# is read (used as the epoch length/hop, in seconds, when cutting the recording).
config_info = settings.CONFIG_INFO

class HMMClassifier(hmm.BaseHMM):
    """Hidden Markov model whose emission probabilities come from a trained classifier.

    Each hidden state corresponds to one class of ``emission_model``: state ``i``
    is column ``i`` of ``emission_model.predict_proba`` (for sklearn-style
    classifiers that is the order of ``emission_model.classes_``). Only the
    transition matrix is re-estimated by ``fit`` (``params='t'``); both the start
    probabilities and the transition matrix are initialised at the start of
    fitting (``init_params='st'``).

    Parameters
    ----------
    emission_model : classifier
        Fitted classifier exposing ``classes_`` and ``predict_proba``.
    **kwargs
        Forwarded to ``hmmlearn.hmm.BaseHMM`` (e.g. ``n_iter``).
    """
    # TODO: find a way around hmmlearn's use of sklearn.utils.validation.check_array.
    #   For now hmmlearn's source was patched directly (all check_array calls removed),
    #   presumably so that the 3-D inputs accepted by _check_and_set_n_features reach
    #   this class unchanged.
    # TODO: a workable alternative is to restructure the model so that feature
    #   extraction is separate from the final classifier and only the final classifier
    #   is kept here; it would then only need to accept 2-D features.

    def __init__(self, emission_model, **kwargs):
        # One hidden state per class known to the emission classifier.
        n_components = len(emission_model.classes_)
        super().__init__(n_components=n_components, params='t', init_params='st', **kwargs)

        self.emission_model = emission_model

    def _check_and_set_n_features(self, X):
        """Record (or validate) the per-sample feature count of ``X``.

        2-D input is treated as (n_samples, n_features). For 3-D input
        (n_samples, d1, d2) each sample counts as d1 * d2 features.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D or 3-D, or its feature count differs from the one
            recorded on a previous call.
        """
        if X.ndim == 2:
            n_features = X.shape[1]
        elif X.ndim == 3:
            n_features = X.shape[1] * X.shape[2]
        else:
            raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')
        if hasattr(self, "n_features"):
            if self.n_features != n_features:
                raise ValueError(
                    f"Unexpected number of dimensions, got {n_features} but "
                    f"expected {self.n_features}")
        else:
            self.n_features = n_features

    def _get_n_fit_scalars_per_param(self):
        """Return the number of free scalars for each fittable parameter group.

        A probability vector over ``nc`` states has ``nc - 1`` degrees of freedom
        because it must sum to 1. So ``startprob_`` contributes ``nc - 1``, and each
        of the ``nc`` rows of ``transmat_`` contributes ``nc - 1``. hmmlearn uses
        these counts, e.g. in ``aic``/``bic``.
        """
        nc = self.n_components
        return {
            "s": nc - 1,
            "t": nc * (nc - 1),
        }

    def _compute_likelihood(self, X):
        """Return per-state emission scores for every sample in ``X``.

        The scores are the classifier's ``predict_proba`` output, one column per
        hidden state.

        NOTE(review): ``predict_proba`` gives posteriors P(state | x), not likelihoods
        P(x | state). Using them directly is equivalent to the scaled-likelihood
        approach only when the classifier's class priors are uniform. Confirm this
        is intended.
        """
        p = self.emission_model.predict_proba(X)
        return p


def extract_baseline_feature(model, raw, step):
    """Cut a recording into filter-bank epochs for the 'baseline' model type.

    Parameters
    ----------
    model : sequence of length 2
        Trained baseline model. Its first element is the feature extractor
        (``transform`` is applied to ``raw.get_data()``); the second is unused here.
    raw : recording object
        Must expose ``info['sfreq']``, ``times`` and ``get_data()`` (MNE ``Raw``-like;
        assumed, not checked).
    step : float
        Epoch length and hop between epoch starts, in seconds.

    Returns
    -------
    ndarray
        Epoched filter-bank output, decimated twice along the last axis by
        ``q = int(sqrt(sfreq / 10))``. The final rate is ``sfreq / q**2``, which is
        >= 10 Hz and exactly 10 Hz when ``sfreq / 10`` is a perfect square.
    """
    sfreq = raw.info['sfreq']
    extractor, _ = model
    bank = extractor.transform(raw.get_data())
    onsets = _split_continuous((raw.times[0], raw.times[-1]), step, sfreq)
    epochs = bci_utils.cut_epochs((0, step, sfreq), bank, onsets)

    # Downsample towards 10 Hz in two equal stages rather than one large factor.
    # The factor must stay identical to what the classifier was trained with.
    factor = np.sqrt(sfreq / 10).astype(np.int16)
    for _stage in range(2):
        epochs = signal.decimate(epochs, factor, axis=-1, zero_phase=True)
    return epochs


def extract_riemann_feature(model, raw, step):
    """Cut a recording into epochs and map them to covariance features ('riemann' model).

    Parameters
    ----------
    model : sequence of length 4
        (feature extractor, scaler, covariance estimator, classifier). The
        classifier is not used here.
    raw : recording object
        Must expose ``info['sfreq']``, ``times`` and ``get_data()`` (MNE ``Raw``-like;
        assumed, not checked).
    step : float
        Epoch length and hop between epoch starts, in seconds.

    Returns
    -------
    Output of ``cov_model.transform`` on the scaled epochs.
    """
    sfreq = raw.info['sfreq']
    extractor, scaler, cov_model, _ = model
    filtered = extractor.transform(raw.get_data())
    onsets = _split_continuous((raw.times[0], raw.times[-1]), step, sfreq)
    scaled_epochs = scaler.transform(
        bci_utils.cut_epochs((0, step, sfreq), filtered, onsets)
    )
    return cov_model.transform(scaled_epochs)


def _split_continuous(time_range, step, fs):
    return np.arange(int(time_range[0] * fs), 
                           int(time_range[-1] * fs), 
                           int(step * fs), dtype=np.int64)


def parse_args():
    """Parse command-line options for HMM transition-matrix fitting.

    ``--subj`` and ``--model-filename`` are required: the script builds its data
    and model paths from them and cannot run without them.

    Returns
    -------
    argparse.Namespace
        Attributes ``subj``, ``model_filename``, ``state_change_threshold`` and
        ``state_trans_prob``.
    """
    parser = argparse.ArgumentParser(
        description='Fit the HMM transition matrix using a trained classifier as the emission model'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str
    )
    # NOTE(review): the two HMM options below are parsed but not read anywhere in
    # this file; they are presumably consumed by a sibling online/validation script.
    parser.add_argument(
        '--state-change-threshold',
        '-scth',
        dest='state_change_threshold',
        help='Threshold for HMM state change',
        default=0.75,
        type=float
    )
    parser.add_argument(
        '--state-trans-prob',
        '-stp',
        dest='state_trans_prob',
        help='Transition probability for HMM state change',
        default=0.8,
        type=float
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model filename',
        required=True,
        type=str
    )
    return parser.parse_args()

# ---- Script body: load the trained classifier, fit the HMM, save its transition matrix ----
args = parse_args()
subj_name = args.subj
model_filename = args.model_filename

data_dir = f'./data/{subj_name}/'

model_path = f'./static/models/{subj_name}/{model_filename}'

# Load the trained model. Its type (parsed from the filename by a project helper)
# selects the matching feature extractor below.
model_type, _ = bci_utils.parse_model_type(model_filename)
# NOTE: joblib.load unpickles the file; only load trusted model files.
model = joblib.load(model_path)

# The sessions used for HMM fitting are listed under 'hmm_sessions' in train_info.yml.
with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
    info = yaml.safe_load(f)
sessions = info['hmm_sessions']

# NOTE(review): the meaning of the third positional argument (True) is defined in
# dataloaders.neo, which is not visible here.
raw, event_id = neo.raw_loader(data_dir, sessions, True)

# Cut the continuous recording into consecutive buffer-length epochs and extract
# the features the trained classifier expects.
if model_type == 'baseline':
    feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
elif model_type == 'riemann':
    feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
else:
    raise ValueError(
        f"Unsupported model type {model_type!r}; expected 'baseline' or 'riemann'")

# The model's last element (its final classifier) serves as the HMM emission model;
# only the transition matrix is learned.
hmm_model = HMMClassifier(model[-1], n_iter=100)
hmm_model.fit(feature)

# Decode the same data and plot one decoded state index per row of `feature`.
log_probs, state_seqs = hmm_model.decode(feature)
plt.figure()
plt.plot(state_seqs)
plt.xlabel('Epoch index')
plt.ylabel('Decoded state index')

# Save the learned transition matrix next to the model file. splitext strips only
# the final extension, so model names containing extra dots are kept intact.
model_stem = os.path.splitext(model_filename)[0]
np.savetxt(f'./static/models/{subj_name}/{model_stem}_transmat.txt', hmm_model.transmat_)

plt.show()