import itertools
import logging
import os
from datetime import datetime

import joblib
import numpy as np
from mne import baseline
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold

logger = logging.getLogger(__name__)

def event_to_stim_channel(events, time_length, trial_length=None, start_ind=0):
    """Expand an MNE-style events array (n_events, 3) into a dense label channel.

    If trial_length is given, each event fills a fixed-length window;
    otherwise each event's label is held until the next event's onset.
    """
    x = np.zeros(time_length, dtype=np.int32)
    if trial_length is not None:
        for i in range(len(events)):
            ind = events[i, 0] - start_ind
            x[ind:ind + trial_length] = events[i, 2]
    else:
        for i in range(len(events) - 1):
            ind_start = events[i, 0] - start_ind
            ind_end = events[i + 1, 0] - start_ind
            x[ind_start:ind_end] = events[i, 2]
    return x
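
# Illustrative sketch (made-up events: onsets at samples 0 and 5, labels 1 and 2):
# events = np.array([[0, 0, 1], [5, 0, 2]])
# event_to_stim_channel(events, time_length=10, trial_length=3)
# -> array([1, 1, 1, 0, 0, 2, 2, 2, 0, 0], dtype=int32)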

def count_transmat_by_events(events):
    """Estimate a transition-probability prior from the event label sequence."""
    y = events[:, -1]
    classes = np.unique(y)
    classes_ind = {c: i for i, c in enumerate(classes)}
    transmat_prior = np.zeros((len(classes), len(classes)))
    for i in range(len(y) - 1):
        transmat_prior[classes_ind[y[i]], classes_ind[y[i + 1]]] += 1
    # normalize each row into a probability distribution, guarding against
    # all-zero rows (classes that never transition out)
    row_sums = transmat_prior.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1
    transmat_prior /= row_sums
    return transmat_prior
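
# Illustrative sketch (label sequence 1, 1, 2, 1 gives transitions 1->1, 1->2, 2->1):
# events = np.array([[0, 0, 1], [10, 0, 1], [20, 0, 2], [30, 0, 1]])
# count_transmat_by_events(events)
# -> array([[0.5, 0.5],
#           [1. , 0. ]])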

def model_saver(model, model_path, model_type, subject_id, event_id):
    """Pickle a trained model under model_path/subject_id with a descriptive name."""
    # event names sorted by class label, so the file name lists classes in label order
    sorted_events = sorted(event_id.items(), key=lambda item: item[1])
    sorted_events = [item[0] for item in sorted_events]
    # create the subject directory (and any missing parents) if needed
    os.makedirs(os.path.join(model_path, subject_id), exist_ok=True)
    now = datetime.now()
    classes = '+'.join(sorted_events)
    date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
    model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
    joblib.dump(model, os.path.join(model_path, subject_id, model_name))

def parse_model_type(model_path):
    """Recover (model_type, events) from a file name written by model_saver."""
    model_path = os.path.normpath(model_path)
    file_name = model_path.split(os.sep)[-1]
    # expected file name format: {model_type}_{classes}_{datetime}.pkl
    model_type, events, _ = file_name.split('_')
    events = events.split('+')
    return model_type.lower(), events
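
# Illustrative round trip of the naming scheme (hypothetical file name):
# parse_model_type('models/S01/HMM_rest+hand_01-02-2024-12-00-00.pkl')
# -> ('hmm', ['rest', 'hand'])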

def event_metric(event_true, event_pred, fs, hit_time_range=(-0.5, 2), ignore_event=(0,), f_beta=1.):
    """Evaluate single-trial event detection with an F-beta score.

    A predicted event counts as a hit when its onset falls within
    hit_time_range (in seconds, relative to the true onset) and it carries
    the same label as the true event.

    Args:
        event_true: MNE-style events array (n_events, 3) of ground-truth events
        event_pred: MNE-style events array (n_events, 3) of predicted events
        fs: sampling rate in Hz
        hit_time_range (tuple): (start, end) hit window in seconds around a true onset
        ignore_event (tuple): event codes to exclude from scoring
        f_beta (float): beta weight of the F-beta score
    Return:
        (precision, recall, fbeta_score): floats
    """
    event_true = event_true.copy()[np.logical_not(np.isin(event_true[:, 2], ignore_event))]
    event_pred = event_pred.copy()[np.logical_not(np.isin(event_pred[:, 2], ignore_event))]
    true_idx = 0
    pred_idx = 0
    correct_count = 0
    # convert the hit window from seconds to samples
    hit_time_range = (int(fs * hit_time_range[0]), int(fs * hit_time_range[1]))
    # two-pointer sweep over both time-sorted event lists
    while true_idx < len(event_true) and pred_idx < len(event_pred):
        if event_true[true_idx, 0] + hit_time_range[0] <= event_pred[pred_idx, 0] < event_true[true_idx, 0] + hit_time_range[1]:
            if event_true[true_idx, 2] == event_pred[pred_idx, 2]:
                # label match inside the hit window: count a hit, advance both
                correct_count += 1
                true_idx += 1
                pred_idx += 1
            else:
                pred_idx += 1
        elif event_pred[pred_idx, 0] < event_true[true_idx, 0] + hit_time_range[0]:
            # prediction earlier than the current window: discard it
            pred_idx += 1
        else:
            # prediction beyond the window: the current true event was missed
            true_idx += 1
    precision = correct_count / len(event_pred) if len(event_pred) > 0 else 0.
    recall = correct_count / len(event_true) if len(event_true) > 0 else 0.
    if f_beta ** 2 * precision + recall > 0:
        fbeta_score = (1 + f_beta ** 2) * (precision * recall) / (f_beta ** 2 * precision + recall)
    else:
        fbeta_score = 0.
    return precision, recall, fbeta_score
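
# Illustrative sketch (fs=10 Hz, so the default hit window is [-5, 20] samples):
# event_true = np.array([[100, 0, 1], [300, 0, 2]])
# event_pred = np.array([[105, 0, 1], [400, 0, 2]])
# event_metric(event_true, event_pred, fs=10)
# -> (0.5, 0.5, 0.5): the first prediction hits, the second is 100 samples late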

def cut_epochs(t, data, timestamps):
    """
    Cut continuous data into fixed-length epochs around each timestamp.
    :param t: tuple (start, end, samplerate), start/end in seconds
    :param data: ndarray (..., n_times), the last dimension must be time
    :param timestamps: list of event onsets in samples
    :return: ndarray (n_epochs, ..., n_times), the first dimension is epochs
    """
    timestamps = np.array(timestamps)
    start = timestamps + int(t[0] * t[2])
    end = timestamps + int(t[1] * t[2])
    # boundary check: drop epochs that would run past either edge of the data
    if start[0] < 0:
        start = start[1:]
        end = end[1:]
    if end[-1] > data.shape[-1]:
        start = start[:-1]
        end = end[:-1]
    epochs = np.stack([data[..., s:e] for s, e in zip(start, end)], axis=0)
    return epochs
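
# Illustrative sketch (8 channels, 1000 samples at 100 Hz, epochs of 1 s):
# data = np.random.randn(8, 1000)
# cut_epochs((-0.2, 0.8, 100), data, [300, 600]).shape  # -> (2, 8, 100)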

def apply_baseline(t, data, mode='mean'):
    """
    Simple wrapper around the mne.baseline.rescale function.
    :param t: tuple (start, end, samplerate), start/end in seconds
    :param data: ndarray of any shape with axis=-1 the time axis
    :param mode: 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio',
        refer to mne.baseline.rescale
    :return: ndarray, baseline-corrected data of the same shape
    """
    start, end, samplerate = t
    # baseline interval runs from the epoch start up to the event onset (t=0)
    base = (start, 0)
    times = np.linspace(start, end, data.shape[-1])
    data = baseline.rescale(data, times, baseline=base, mode=mode, verbose=False)
    return data
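
# Illustrative sketch (the baseline window is [start, 0] s, here [-0.2, 0] s):
# epochs = np.random.randn(2, 8, 100)
# apply_baseline((-0.2, 0.8, 100), epochs).shape  # -> (2, 8, 100)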

def product_dict(**kwargs):
    """Yield one dict per combination in the Cartesian product of the value lists."""
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))
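
# Illustrative sketch:
# list(product_dict(C=[0.1, 1], kernel=['linear', 'rbf']))
# -> [{'C': 0.1, 'kernel': 'linear'}, {'C': 0.1, 'kernel': 'rbf'},
#     {'C': 1, 'kernel': 'linear'}, {'C': 1, 'kernel': 'rbf'}]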

def param_search(model_func, X, y, params: dict, random_state=123):
    """
    Grid search over params, scored by 10-fold cross-validated AUC.
    :param model_func: model builder, called as model_func(**param_dict)
    :param X: ndarray (n_trials, n_channels, n_times)
    :param y: ndarray (n_trials, )
    :param params: dict of params, key is param name and value is search range
    :param random_state: seed for the KFold shuffle
    :return: (best_auc, best_param)
    """
    kfold = KFold(n_splits=10, shuffle=True, random_state=random_state)
    best_auc = -1
    best_param = None
    for p_dict in product_dict(**params):
        model = model_func(**p_dict)
        n_classes = len(np.unique(y))
        y_pred = np.zeros((len(y), n_classes))
        # collect out-of-fold predicted probabilities for every trial
        for train_idx, test_idx in kfold.split(X):
            X_train, y_train = X[train_idx], y[train_idx]
            X_test = X[test_idx]
            model.fit(X_train, y_train)
            y_pred[test_idx] = model.predict_proba(X_test)
        auc = multiclass_auc_score(y, y_pred, n_classes=n_classes)
        # keep the best-scoring parameter combination
        if auc > best_auc:
            best_param = p_dict
            best_auc = auc
        # log each step
        logger.debug(f'Current: {p_dict}, {auc}; Best: {best_param}, {best_auc}')
    return best_auc, best_param
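
# Illustrative sketch, assuming a hypothetical build_model(**kwargs) that returns
# an sklearn-style estimator accepting (n_trials, n_channels, n_times) input:
# search_space = {'n_components': [2, 4, 6], 'shrinkage': [0.01, 0.1]}
# best_auc, best_param = param_search(build_model, X, y, search_space)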

def multiclass_auc_score(y_true, prob, n_classes=None):
    """ROC AUC that dispatches between binary and one-vs-rest multiclass scoring."""
    if n_classes is None:
        n_classes = len(np.unique(y_true))
    if n_classes > 2:
        auc = roc_auc_score(y_true, prob, multi_class='ovr')
    elif n_classes == 2:
        # binary case: score the probability of the positive class
        auc = roc_auc_score(y_true, prob[:, 1])
    else:
        raise ValueError(f'Expected at least 2 classes, got {n_classes}')
    return auc
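
# Illustrative sketch (binary case scores the positive-class column):
# y_true = np.array([0, 0, 1, 1])
# prob = np.array([[0.9, 0.1], [0.6, 0.4], [0.3, 0.7], [0.2, 0.8]])
# multiclass_auc_score(y_true, prob)  # -> 1.0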

def reref(data, method):
    """Re-reference (n_channels, n_times) data with the given method."""
    data = data.copy()
    if method == 'average':
        # common average reference: subtract the across-channel mean
        data -= data.mean(axis=0)
        return data
    elif method == 'bipolar':
        # NEO-specific montage: difference between neighboring channels
        anode = data[[0, 1, 2, 3, 7, 6, 5]]
        cathode = data[[1, 2, 3, 7, 6, 5, 4]]
        return anode - cathode
    elif method == 'monopolar':
        return data
    else:
        raise ValueError(f'Unacceptable rereference method: got {method!r}, '
                         f'expected "monopolar", "average" or "bipolar"')
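
# Illustrative sketch (8-channel data, as the bipolar montage assumes):
# data = np.random.randn(8, 500)
# reref(data, 'average').shape  # -> (8, 500)
# reref(data, 'bipolar').shape  # -> (7, 500)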