from scipy import io as sio
import mne
import numpy as np
from .utils import upsample_events


# loader for test data
def raw_preprocessing(data_file, finger_model='cylinder', epoch_time=1., fs=1000):
    """Load a MATLAB recording and return it as an annotated ``mne.io.RawArray``.

    The file must contain a ``'data'`` matrix (transposed so that rows become
    channels) and a ``'stim'`` vector. The data is scaled to volts and
    notch-filtered at 60/120/180 Hz. Task/rest events are then derived from
    ``'stim'`` via ``extract_events`` and ``upsample_events``, and attached as
    annotations.

    Parameters
    ----------
    data_file : str or path-like
        Path passed to ``scipy.io.loadmat``.
    finger_model : str
        Annotation description used for task events (event code 1).
    epoch_time : float
        Epoch length in seconds, forwarded to ``upsample_events`` as a sample count.
    fs : int or float
        Sampling frequency in Hz.

    Returns
    -------
    mne.io.RawArray
        ECoG channels named ``ch_0``, ``ch_1``, ... with task/rest annotations.
    """
    mat = sio.loadmat(data_file, simplify_cells=True, squeeze_me=True)

    # Convert to float64 and scale to volts. 0.0298 is presumably the
    # amplifier gain in uV per count (TODO confirm); 1e-6 converts uV -> V.
    signal = mat['data'].astype(np.float64).T * 0.0298 * 1e-6
    stim = mat['stim'].astype(np.float64)

    # Suppress 60 Hz line noise and its 120/180 Hz harmonics.
    signal = mne.filter.notch_filter(signal, fs, [60, 120, 180], trans_bandwidth=3, verbose=False)

    # Derive task/rest events, then expand them to epoch_time-long windows.
    trial_events = upsample_events(extract_events(stim, fs), int(epoch_time * fs))

    ch_names = [f'ch_{idx}' for idx in range(signal.shape[0])]
    raw = mne.io.RawArray(signal, mne.create_info(ch_names, sfreq=fs, ch_types='ecog'))

    raw.set_annotations(
        mne.annotations_from_events(trial_events, fs, {1: finger_model, 0: 'rest'})
    )
    return raw


def extract_events(stim_events, fs=1000., stim_code=12, reaction_time=0.5):
    """Build alternating task/rest events from a stimulus channel.

    A task trial starts where the stimulus channel steps up by ``stim_code``
    and ends where it steps down by ``stim_code``. Task onsets are delayed by
    ``reaction_time`` seconds to compensate for reaction time. Each rest event
    starts ``reaction_time`` seconds after the corresponding task offset.
    Trials cut by the start or end of the recording are discarded.

    Parameters
    ----------
    stim_events : array_like, 1-D
        Stimulus channel samples.
    fs : float
        Sampling frequency in Hz.
    stim_code : int or float
        Step size that marks a trial (12 selects hand-only trials).
    reaction_time : float
        Delay in seconds applied to task onsets and rest onsets.

    Returns
    -------
    events : ndarray of int64, shape (2 * n_trials, 3)
        Rows alternate task (column 2 == 1) and rest (column 2 == 0).
        Column 0 is the sample index. Column 1 holds the shifted trial length
        (offset - shifted onset, in samples) on every row.

    Raises
    ------
    ValueError
        If no complete trial is found, onsets/offsets do not alternate,
        or trials have unequal lengths.
    """
    # Work in float64 so unsigned integer input cannot wrap to +244 etc. in diff.
    diff_stim = np.diff(np.asarray(stim_events, dtype=np.float64))

    shift_idx = int(reaction_time * fs)  # compensate for reaction time

    # diff index i marks the change between samples i and i+1; both edges use
    # the same convention, so trial lengths are unaffected.
    raw_onsets = np.flatnonzero(diff_stim == stim_code)
    offsets = np.flatnonzero(diff_stim == -stim_code)

    # Drop partial trials at the recording edges. Offsets before the first
    # onset mean the recording started mid-trial; onsets after the last offset
    # mean it ended mid-trial. Both edges are handled independently, so a
    # recording cut at both ends is also paired correctly. Unshifted onsets
    # are used so the reaction-time shift cannot affect the pairing.
    if raw_onsets.size:
        offsets = offsets[offsets > raw_onsets[0]]
    if offsets.size:
        raw_onsets = raw_onsets[raw_onsets < offsets[-1]]

    if raw_onsets.size == 0 or offsets.size == 0:
        raise ValueError('No complete trial found in stim_events.')
    if raw_onsets.size != offsets.size:
        raise ValueError('Unmatched trial onsets and offsets.')
    # Each onset must precede its offset, and each offset the next onset.
    if np.any(offsets <= raw_onsets) or np.any(raw_onsets[1:] <= offsets[:-1]):
        raise ValueError('Trial onsets and offsets do not alternate.')

    onsets = raw_onsets + shift_idx
    rest_onsets = offsets + shift_idx

    lengths = offsets - onsets
    if np.unique(lengths).size > 1:
        raise ValueError('Unequal trial length?')
    trial_length = lengths[0]

    # Interleave task (even rows, code 1) and rest (odd rows, code 0) events.
    events = np.zeros((2 * onsets.size, 3), dtype=np.int64)
    events[::2, 0] = onsets
    events[1::2, 0] = rest_onsets
    events[:, 1] = trial_length
    events[::2, 2] = 1
    return events