import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import StandardScaler

from mne.decoding import Vectorizer


class ChannelScaler(BaseEstimator, TransformerMixin):
    """Standardize data using a mean and standard deviation pooled over ``norm_axis``.

    With input shaped ``(n_epochs, n_channels, n_times)`` and the default
    ``norm_axis=(0, 2)``, statistics are pooled over epochs and time, so one
    mean and one standard deviation are learned per channel.

    :param norm_axis: axis or tuple of axes passed to ``np.mean``/``np.std``
        when computing the statistics in ``fit``.
    """

    def __init__(self, norm_axis=(0, 2)):
        # Fitted attributes are pre-set to None for backward compatibility;
        # ``__sklearn_is_fitted__`` tells scikit-learn they only count as
        # fitted once ``fit`` has filled them in.
        self.channel_mean_ = None
        self.channel_std_ = None
        self.norm_axis = norm_axis

    def __sklearn_is_fitted__(self):
        """Return True once ``fit`` has stored the mean and standard deviation."""
        return self.channel_mean_ is not None and self.channel_std_ is not None

    def fit(self, X, y=None):
        """Learn the pooled mean and standard deviation of X.

        :param X: array-like; 3d with shape (n_epochs, n_channels, n_times)
            for the default ``norm_axis``.
        :param y: ignored; present for scikit-learn API compatibility.
        :return: self
        """
        X = np.asarray(X)
        # keepdims=True keeps the reduced axes as size-1 dimensions so the
        # statistics broadcast directly against X in ``transform``.
        self.channel_mean_ = np.mean(X, axis=self.norm_axis, keepdims=True)
        self.channel_std_ = np.std(X, axis=self.norm_axis, keepdims=True)
        return self

    def transform(self, X, y=None):
        """Return a standardized copy of X using the statistics learned in ``fit``.

        Entries whose fitted standard deviation is exactly zero (constant
        channels) are only centered, so they map to 0 instead of NaN/inf.

        :param X: array-like with the same layout as the data given to ``fit``.
        :param y: ignored; present for scikit-learn API compatibility.
        :return: a new ndarray; the input is not modified.
        :raises NotFittedError: if called before ``fit``.
        """
        if not self.__sklearn_is_fitted__():
            from sklearn.exceptions import NotFittedError
            raise NotFittedError(
                "This ChannelScaler instance is not fitted yet. "
                "Call 'fit' before using 'transform'."
            )
        X = np.asarray(X)
        # Same convention as sklearn's StandardScaler: a zero scale is
        # replaced by 1 so dividing leaves constant channels centered at 0.
        # Copy-and-assign keeps the dtype of the fitted std.
        scale = self.channel_std_.copy()
        scale[scale == 0] = 1
        # Out-of-place arithmetic never touches the caller's array and also
        # works for integer input, where the previous in-place ``-=`` on an
        # int copy raised a casting error.
        return (X - self.channel_mean_) / scale


def baseline_model(C=1.):
    """Build the baseline decoding pipeline.

    Steps, in order: ``Vectorizer`` (from ``mne.decoding``, flattens each
    sample into a feature vector), ``StandardScaler`` and
    ``LogisticRegression``.

    :param C: inverse regularization strength forwarded to
        ``LogisticRegression``.
    :return: an unfitted scikit-learn ``Pipeline``.
    """
    steps = [
        Vectorizer(),
        StandardScaler(),
        LogisticRegression(C=C),
    ]
    return make_pipeline(*steps)