import argparse
import logging
import os
from datetime import datetime

import joblib
import mne
import numpy as np
import yaml
from scipy import signal
from pyriemann.estimation import BlockCovariances

import bci_core.feature_extractors as feature_extractors
import bci_core.utils as bci_utils
import bci_core.model as bci_model
from dataloaders import neo
from settings.config import settings


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config_info = settings.CONFIG_INFO


def parse_args():
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name (folder under ./data/)',
        required=True,
        type=str
    )
    parser.add_argument(
        '--model-type',
        dest='model_type',
        help='Model type: baseline or riemann',
        default='baseline',
        type=str
    )
    return parser.parse_args()
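
# Example invocation (script name and subject "S01" are illustrative):
#   python train.py --subj S01 --model-type riemann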


def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
    """Train a 'baseline' or 'riemann' decoder on annotated continuous data.

    Extracts events from the raw annotations via `event_id`, cuts epochs of
    `trial_duration` seconds, and returns the fitted preprocessing stages
    together with the trained classifier. `model_kwargs` are forwarded to the
    model-specific training routine.
    """
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    if model_type.lower() == 'baseline':
        model = _train_baseline_model(raw, events, duration=trial_duration, **model_kwargs)
    elif model_type.lower() == 'riemann':
        model = _train_riemann_model(raw, events, duration=trial_duration, **model_kwargs)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')
    return model
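
# Minimal usage sketch (hedged: assumes `sessions` was read from train_info.yml
# as in the __main__ block below; the subject path is illustrative):
#   raw, event_id = neo.raw_loader('./data/S01/', sessions,
#                                  upsampled_epoch_length=1., ori_epoch_length=5)
#   model = train_model(raw, event_id, trial_duration=1., model_type='riemann')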


def _train_riemann_model(raw, events, duration=1., lf_bands=((15, 35), (35, 50)), hg_bands=((55, 95), (105, 145))):
    """Band-filter the data, scale channels, estimate block covariances over
    the low-frequency and high-gamma bands, then grid-search C for the
    Riemannian classifier."""
    fs = raw.info['sfreq']
    n_ch = len(raw.ch_names)
    feat_extractor = feature_extractors.FeatExtractor(fs, lf_bands, hg_bands)
    filtered_data = feat_extractor.transform(raw.get_data())
    X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
    y = events[:, -1]
    
    scaler = bci_model.ChannelScaler()
    X = scaler.fit_transform(X)

    # compute covariance
    lfb_dim = len(lf_bands) * n_ch
    hgs_dim = len(hg_bands) * n_ch
    cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
    X_cov = cov_model.fit_transform(X)

    param = {'C': np.logspace(-5, 4, 10)}

    best_auc, best_param = bci_utils.param_search(bci_model.riemann_model, X_cov, y, param)

    logger.info(f'Best parameter: {best_param}, best auc {best_auc}')

    # train and dump best model
    model_to_train = bci_model.riemann_model(**best_param)
    model_to_train.fit(X_cov, y)
    return [feat_extractor, scaler, cov_model, model_to_train]
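
# Inference-time sketch for the returned Riemann pipeline (hedged: assumes each
# fitted stage exposes the usual sklearn-style transform/predict; `buffer` is a
# hypothetical (n_ch, n_times) window of trial_duration seconds):
#   feat_extractor, scaler, cov_model, clf = model
#   x = feat_extractor.transform(buffer)[np.newaxis]  # add the epoch axis
#   y_pred = clf.predict(cov_model.transform(scaler.transform(x)))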


def _train_baseline_model(raw, events, duration=1., freqs=(20, 150, 15)):
    """Extract filterbank features over np.arange(*freqs), downsample to 10 Hz,
    then grid-search C for the baseline classifier."""
    fs = raw.info['sfreq']
    freqs = np.arange(*freqs)
    filterbank_extractor = feature_extractors.FilterbankExtractor(fs, freqs)
    filter_bank_data = filterbank_extractor.transform(raw.get_data())

    filter_bank_epoch = bci_utils.cut_epochs((0, duration, fs), filter_bank_data, events[:, 0])

    # downsample to 10 Hz in two equal stages, each by a factor of sqrt(fs / 10)
    # (e.g. fs = 1000 Hz: 1000 -> 100 -> 10 Hz); staged decimation keeps the
    # anti-aliasing filters well conditioned
    decimate_rate = int(np.sqrt(fs / 10))
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
    X = filter_bank_epoch
    y = events[:, -1]

    best_auc, best_param = bci_utils.param_search(bci_model.baseline_model, X, y, {'C': np.logspace(-5, 4, 10)})
    logger.info(f'Best parameter: {best_param}, best auc {best_auc}')

    model_to_train = bci_model.baseline_model(**best_param)
    model_to_train.fit(X, y)
    return filterbank_extractor, model_to_train
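
# Matching inference sketch for the baseline pipeline (hedged: `buffer` is a
# hypothetical (n_ch, n_times) window; the two-stage decimation must mirror
# training, reusing the same decimate_rate):
#   extractor, clf = model
#   x = extractor.transform(buffer)[np.newaxis]
#   for _ in range(2):
#       x = signal.decimate(x, decimate_rate, axis=-1, zero_phase=True)
#   y_pred = clf.predict(x)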


def model_saver(model, model_path, model_type, subject_id, event_id):
    """Dump the fitted model to model_path/subject_id, encoding the class names
    (sorted by label) and a timestamp in the file name."""
    # event names sorted by class label, so the file name reflects class order
    sorted_events = sorted(event_id.items(), key=lambda item: item[1])
    sorted_events = [item[0] for item in sorted_events]

    os.makedirs(os.path.join(model_path, subject_id), exist_ok=True)

    now = datetime.now()
    classes = '+'.join(sorted_events)
    date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
    model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
    joblib.dump(model, os.path.join(model_path, subject_id, model_name))
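
# Round-trip sketch: restore the dumped artifact with joblib (path illustrative):
#   model = joblib.load('./static/models/S01/baseline_cls1+cls2_01-31-2025-12-00-00.pkl')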


if __name__ == '__main__':
    args = parse_args()
    subj_name = args.subj
    model_type = args.model_type

    data_dir = f'./data/{subj_name}/'
    model_dir = './static/models/'

    with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
        model_config = yaml.safe_load(f)[model_type]

    with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
    
    trial_duration = config_info['buffer_length']
    # preprocess raw
    raw, event_id = neo.raw_loader(data_dir, sessions, upsampled_epoch_length=trial_duration, ori_epoch_length=5)

    # train model
    model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)
    
    # save
    model_saver(model, model_dir, model_type, subj_name, event_id)