import itertools
import logging
import os
from datetime import datetime

import joblib
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold

logger = logging.getLogger(__name__)


def event_to_stim_channel(events, time_length, trial_length=None, start_ind=0):
    """Expand an events array of shape (n_events, 3), whose columns are
    (onset sample, _, label), into a per-sample label vector of length
    ``time_length``.

    If ``trial_length`` is given, each label is held for that many samples;
    otherwise each label is held until the next event's onset.
    """
    x = np.zeros(time_length, dtype=np.int32)
    if trial_length is not None:
        for i in range(len(events)):
            ind = events[i, 0] - start_ind
            x[ind:ind + trial_length] = events[i, 2]
    else:
        for i in range(len(events) - 1):
            ind_start = events[i, 0] - start_ind
            ind_end = events[i + 1, 0] - start_ind
            x[ind_start:ind_end] = events[i, 2]
    return x
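

# Example usage of event_to_stim_channel (illustrative; values are made up):
#     events = np.array([[100, 0, 1], [300, 0, 2], [500, 0, 0]])
#     x = event_to_stim_channel(events, time_length=600)
#     # x[100:300] == 1, x[300:500] == 2, everything else 0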


def count_transmat_by_events(events):
    """Estimate a transition-probability prior between event classes by
    counting consecutive label pairs in the events array."""
    y = events[:, -1]
    classes = np.unique(y)
    classes_ind = {c: i for i, c in enumerate(classes)}
    transmat_prior = np.zeros((len(classes), len(classes)))
    for i in range(len(y) - 1):
        transmat_prior[classes_ind[y[i]], classes_ind[y[i + 1]]] += 1
    # normalize each row to probabilities; guard against all-zero rows
    # (a class that only appears as the last event has no outgoing counts)
    row_sums = transmat_prior.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.
    transmat_prior /= row_sums
    return transmat_prior
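

# Example usage of count_transmat_by_events (illustrative): the label
# sequence 1, 2, 1, 2 yields transitions 1->2 (twice) and 2->1 (once):
#     events = np.array([[0, 0, 1], [10, 0, 2], [20, 0, 1], [30, 0, 2]])
#     count_transmat_by_events(events)
#     # array([[0., 1.],
#     #        [1., 0.]])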


def model_saver(model, model_path, model_type, subject_id, event_id):
    """Persist a fitted model as
    ``{model_path}/{subject_id}/{model_type}_{class1+class2+...}_{datetime}.pkl``.

    The class names in the file name are sorted by their label value so that
    ``parse_model_type`` can recover them in label order.
    """
    # sort the event names by class label and keep only the names
    sorted_events = sorted(event_id.items(), key=lambda item: item[1])
    sorted_events = [item[0] for item in sorted_events]

    os.makedirs(os.path.join(model_path, subject_id), exist_ok=True)

    classes = '+'.join(sorted_events)
    date_time_str = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
    model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
    joblib.dump(model, os.path.join(model_path, subject_id, model_name))


def parse_model_type(model_path):
    """Recover (model_type, event names) from a file name written by
    ``model_saver``; assumes neither part contains an underscore."""
    file_name = os.path.basename(os.path.normpath(model_path))
    model_type, events, _ = file_name.split('_')
    events = events.split('+')
    return model_type.lower(), events
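

# Example round trip of model_saver / parse_model_type (illustrative;
# ``clf`` is any fitted estimator and ./models an existing directory):
#     model_saver(clf, './models', 'hmm', 'S01', {'rest': 0, 'move': 1})
#     # writes e.g. ./models/S01/hmm_rest+move_01-31-2024-12-00-00.pkl
#     parse_model_type('hmm_rest+move_01-31-2024-12-00-00.pkl')
#     # ('hmm', ['rest', 'move'])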


def event_metric(event_true, event_pred, fs, hit_time_range=(0, 3), ignore_event=(0,), f_beta=1.):
    """Single-trial event-level precision, recall and F-beta score.

    A predicted event counts as a hit if its onset falls within
    ``hit_time_range`` (in seconds) after a true event's onset and its
    label matches; each true event can be hit at most once.

    Args:
        event_true: ndarray (n_events, 3), true events; columns are
            (onset sample, _, label)
        event_pred: ndarray (n_events, 3), predicted events, same layout
        fs (float): sampling rate (Hz)
        hit_time_range (tuple): (start, end) of the hit window in seconds,
            relative to each true event onset
        ignore_event (tuple): event labels excluded from scoring
        f_beta (float): beta of the F-beta score (1. gives the F1 score)
    Return:
        tuple: (precision, recall, f_beta score)
    """
    event_true = event_true.copy()[np.logical_not(np.isin(event_true[:, 2], ignore_event))]
    event_pred = event_pred.copy()[np.logical_not(np.isin(event_pred[:, 2], ignore_event))]
    true_idx = 0
    pred_idx = 0
    correct_count = 0
    hit_time_range = (int(fs * hit_time_range[0]), int(fs * hit_time_range[1]))
    while true_idx < len(event_true) and pred_idx < len(event_pred):
        if event_true[true_idx, 0] + hit_time_range[0] <= event_pred[pred_idx, 0] < event_true[true_idx, 0] + hit_time_range[1]:
            if event_true[true_idx, 2] == event_pred[pred_idx, 2]:
                # correct label inside the hit window
                correct_count += 1
                true_idx += 1
                pred_idx += 1
            else:
                pred_idx += 1
        elif event_pred[pred_idx, 0] < event_true[true_idx, 0] + hit_time_range[0]:
            # prediction comes before the current true event's window
            pred_idx += 1
        else:
            # prediction is past the window; advance to the next true event
            true_idx += 1

    precision = correct_count / len(event_pred) if len(event_pred) > 0 else 0.
    recall = correct_count / len(event_true) if len(event_true) > 0 else 0.
    if f_beta ** 2 * precision + recall > 0:
        fbeta_score = (1 + f_beta ** 2) * (precision * recall) / (f_beta ** 2 * precision + recall)
    else:
        fbeta_score = 0.
    return precision, recall, fbeta_score
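

# Example usage of event_metric (illustrative): with fs=1 and the default
# 3 s window, the prediction at sample 11 hits the true event at sample 10,
# while the prediction at sample 40 falls outside the window of the event
# at sample 30:
#     event_true = np.array([[10, 0, 1], [30, 0, 1]])
#     event_pred = np.array([[11, 0, 1], [40, 0, 2]])
#     event_metric(event_true, event_pred, fs=1)
#     # (0.5, 0.5, 0.5)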


def cut_epochs(t, data, timestamps):
    """
    Cut continuous data into epochs around the given timestamps.
    :param t: tuple (start, end, samplerate), window in seconds relative to
        each timestamp
    :param data: ndarray (..., n_times), the last dimension must be time
    :param timestamps: list of onset sample indices, assumed sorted ascending
    :return: ndarray (n_epochs, ..., n_times), the first dimension indexes epochs
    """
    timestamps = np.array(timestamps)
    start = timestamps + int(t[0] * t[2])
    end = timestamps + int(t[1] * t[2])
    # boundary check: drop the first/last epoch if it falls outside the data
    if start[0] < 0:
        start = start[1:]
        end = end[1:]
    if end[-1] > data.shape[-1]:
        start = start[:-1]
        end = end[:-1]
    epochs = np.stack([data[..., s:e] for s, e in zip(start, end)], axis=0)
    return epochs
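

# Example usage of cut_epochs (illustrative): cut 1 s epochs at fs=100 from
# 2-channel data:
#     data = np.random.randn(2, 1000)
#     epochs = cut_epochs((0., 1., 100), data, timestamps=[100, 300, 500])
#     # epochs.shape == (3, 2, 100)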


def product_dict(**kwargs):
    """Yield one dict per combination of the given value lists
    (the Cartesian product over a parameter grid)."""
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))
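

# Example usage of product_dict:
#     list(product_dict(C=[1, 10], kernel=['linear', 'rbf']))
#     # [{'C': 1, 'kernel': 'linear'}, {'C': 1, 'kernel': 'rbf'},
#     #  {'C': 10, 'kernel': 'linear'}, {'C': 10, 'kernel': 'rbf'}]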


def param_search(model_func, X, y, params: dict, random_state=123):
    """
    Grid search over ``params``, scored by 10-fold cross-validated AUC.
    :param model_func: model builder, called as ``model_func(**param_dict)``
    :param X: ndarray (n_trials, n_channels, n_times)
    :param y: ndarray (n_trials, )
    :param params: dict of params, key is param name and value is search range
    :param random_state: seed for the KFold shuffling
    :return: tuple (best_auc, best_param)
    """
    kfold = KFold(n_splits=10, shuffle=True, random_state=random_state)
    best_auc = -1
    best_param = None
    for p_dict in product_dict(**params):
        model = model_func(**p_dict)
        n_classes = len(np.unique(y))
        y_pred = np.zeros((len(y), n_classes))
        for train_idx, test_idx in kfold.split(X):
            X_train, y_train = X[train_idx], y[train_idx]
            X_test = X[test_idx]
            model.fit(X_train, y_train)
            y_pred[test_idx] = model.predict_proba(X_test)
        auc = multiclass_auc_score(y, y_pred, n_classes=n_classes)
        # keep the best parameter set seen so far
        if auc > best_auc:
            best_param = p_dict
            best_auc = auc

        # log each step
        logger.debug(f'Current: {p_dict}, {auc}; Best: {best_param}, {best_auc}')
    return best_auc, best_param
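

# Example usage of param_search (illustrative; ``make_model`` is a
# hypothetical builder returning an estimator with fit/predict_proba that
# accepts X of shape (n_trials, n_channels, n_times)):
#     grid = {'C': [0.1, 1., 10.]}
#     best_auc, best_param = param_search(make_model, X, y, grid)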


def multiclass_auc_score(y_true, prob, n_classes=None):
    """ROC-AUC from predicted probabilities ``prob`` of shape
    (n_samples, n_classes); one-vs-rest when more than two classes."""
    if n_classes is None:
        n_classes = len(np.unique(y_true))
    if n_classes > 2:
        auc = roc_auc_score(y_true, prob, multi_class='ovr')
    elif n_classes == 2:
        auc = roc_auc_score(y_true, prob[:, 1])
    else:
        raise ValueError(f'Expected at least 2 classes, got {n_classes}')
    return auc
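

# Example usage of multiclass_auc_score (illustrative): the binary case
# scores on the positive-class column only:
#     y_true = np.array([0, 0, 1, 1])
#     prob = np.array([[0.9, 0.1], [0.6, 0.4], [0.3, 0.7], [0.2, 0.8]])
#     multiclass_auc_score(y_true, prob)
#     # 1.0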


def reref(data, method):
    """Re-reference channel data of shape (n_channels, n_times).

    'average' subtracts the common average across channels, 'bipolar' takes
    differences between neighboring contacts (NEO-specific channel order,
    hardcoded below), and 'monopolar' returns the data unchanged.
    """
    data = data.copy()
    if method == 'average':
        data -= data.mean(axis=0)
        return data
    elif method == 'bipolar':
        # NEO-specific electrode layout
        anode = data[[0, 1, 2, 3, 7, 6, 5]]
        cathode = data[[1, 2, 3, 7, 6, 5, 4]]
        return anode - cathode
    elif method == 'monopolar':
        return data
    else:
        raise ValueError(f'Unacceptable rereference method: got {str(method)}, '
                         f'expected "monopolar", "average" or "bipolar"')
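

# Example usage of reref (illustrative): common-average reference for
# 8-channel data:
#     data = np.random.randn(8, 1000)
#     car = reref(data, 'average')
#     # car.mean(axis=0) is ~0 at every sample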