utils.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. import numpy as np
  2. import itertools
  3. from datetime import datetime
  4. import joblib
  5. from sklearn.model_selection import KFold
  6. from sklearn.metrics import roc_auc_score
  7. import logging
  8. import os
  9. logger = logging.getLogger(__name__)
  10. def event_to_stim_channel(events, time_length, trial_length=None, start_ind=0):
  11. x = np.zeros(time_length, dtype=np.int32)
  12. if trial_length is not None:
  13. for i in range(0, len(events)):
  14. ind = events[i, 0] - start_ind
  15. x[ind:ind + trial_length] = events[i, 2]
  16. else:
  17. for i in range(0, len(events) - 1):
  18. ind_start = events[i, 0] - start_ind
  19. ind_end = events[i + 1, 0] - start_ind
  20. x[ind_start:ind_end] = events[i, 2]
  21. return x
  22. def count_transmat_by_events(events):
  23. y = events[:, -1]
  24. classes = np.unique(y)
  25. classes_ind = {c: i for i, c in enumerate(classes)}
  26. transmat_prior = np.zeros((len(classes), len(classes)))
  27. for i in range(len(y) - 1):
  28. transmat_prior[classes_ind[y[i]], classes_ind[y[i + 1]]] += 1
  29. # normalize
  30. transmat_prior /= np.sum(transmat_prior, axis=1, keepdims=True)
  31. return transmat_prior
  32. def model_saver(model, model_path, model_type, subject_id, event_id):
  33. # event list should be sorted by class label
  34. sorted_events = sorted(event_id.items(), key=lambda item: item[1])
  35. # Extract the keys in the sorted order and store them in a list
  36. sorted_events = [item[0] for item in sorted_events]
  37. try:
  38. os.mkdir(os.path.join(model_path, subject_id))
  39. except FileExistsError:
  40. pass
  41. now = datetime.now()
  42. classes = '+'.join(sorted_events)
  43. date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
  44. model_name = f'{model_type}_{classes}_{date_time_str}.pkl'
  45. joblib.dump(model, os.path.join(model_path, subject_id, model_name))
  46. def parse_model_type(model_path):
  47. model_path = os.path.normpath(model_path)
  48. file_name = model_path.split(os.sep)[-1]
  49. model_type, events, _ = file_name.split('_')
  50. events = events.split('+')
  51. return model_type.lower(), events
  52. def event_metric(event_true, event_pred, fs, hit_time_range=(0, 3), ignore_event=(0,), f_beta=1.):
  53. """评价单试次f_alpha score
  54. Args:
  55. event_true:
  56. event_pred:
  57. fs:
  58. hit_time_range (tuple):
  59. ignore_event (tuple): ignore certain events
  60. f_beta (float): f_(alpha) score
  61. Return:
  62. f_beta score (float): f_alpha score
  63. """
  64. event_true = event_true.copy()[np.logical_not(np.isin(event_true[:, 2], ignore_event))]
  65. event_pred = event_pred.copy()[np.logical_not(np.isin(event_pred[:, 2], ignore_event))]
  66. true_idx = 0
  67. pred_idx = 0
  68. correct_count = 0
  69. hit_time_range = (int(fs * hit_time_range[0]), int(fs * hit_time_range[1]))
  70. while true_idx < len(event_true) and pred_idx < len(event_pred):
  71. if event_true[true_idx, 0] + hit_time_range[0] <= event_pred[pred_idx, 0] < event_true[true_idx, 0] + hit_time_range[1]:
  72. if event_true[true_idx, 2] == event_pred[pred_idx, 2]:
  73. correct_count += 1
  74. true_idx += 1
  75. pred_idx += 1
  76. else:
  77. pred_idx += 1
  78. elif event_pred[pred_idx, 0] < event_true[true_idx, 0] + hit_time_range[0]:
  79. pred_idx += 1
  80. else:
  81. true_idx += 1
  82. if len(event_pred) > 0:
  83. precision = correct_count / len(event_pred)
  84. else:
  85. precision = 0.
  86. recall = correct_count / len(event_true)
  87. if f_beta ** 2 * precision + recall > 0:
  88. fbeta_score = (1 + f_beta ** 2) * (precision * recall) / (f_beta ** 2 * precision + recall)
  89. else:
  90. fbeta_score = 0.
  91. return precision, recall, fbeta_score
  92. def cut_epochs(t, data, timestamps):
  93. """
  94. cutting raw data into epochs
  95. :param t: tuple (start, end, samplerate)
  96. :param data: ndarray (..., n_times), the last dimension should be the times
  97. :param timestamps: list of timestamps
  98. :return: ndarray (n_epochs, ... , n_times), the first dimension be the epochs
  99. """
  100. timestamps = np.array(timestamps)
  101. start = timestamps + int(t[0] * t[2])
  102. end = timestamps + int(t[1] * t[2])
  103. # do boundary check
  104. if start[0] < 0:
  105. start = start[1:]
  106. end = end[1:]
  107. if end[-1] > data.shape[-1]:
  108. start = start[:-1]
  109. end = end[:-1]
  110. epochs = np.stack([data[..., s:e] for s, e in zip(start, end)], axis=0)
  111. return epochs
  112. def product_dict(**kwargs):
  113. keys = kwargs.keys()
  114. vals = kwargs.values()
  115. for instance in itertools.product(*vals):
  116. yield dict(zip(keys, instance))
  117. def param_search(model_func, X, y, params: dict, random_state=123):
  118. """
  119. :param model_func: model builder
  120. :param X: ndarray (n_trials, n_channels, n_times)
  121. :param y: ndarray (n_trials, )
  122. :param params: dict of params, key is param name and value is search range
  123. :param random_state:
  124. :return:
  125. """
  126. kfold = KFold(n_splits=10, shuffle=True, random_state=random_state)
  127. best_auc = -1
  128. best_param = None
  129. for p_dict in product_dict(**params):
  130. model = model_func(**p_dict)
  131. n_classes = len(np.unique(y))
  132. y_pred = np.zeros((len(y), n_classes))
  133. for train_idx, test_idx in kfold.split(X):
  134. X_train, y_train = X[train_idx], y[train_idx]
  135. X_test = X[test_idx]
  136. model.fit(X_train, y_train)
  137. y_pred[test_idx] = model.predict_proba(X_test)
  138. auc = multiclass_auc_score(y, y_pred, n_classes=n_classes)
  139. # update
  140. if auc > best_auc:
  141. best_param = p_dict
  142. best_auc = auc
  143. # print each steps
  144. logger.debug(f'Current: {p_dict}, {auc}; Best: {best_param}, {best_auc}')
  145. return best_auc, best_param
  146. def multiclass_auc_score(y_true, prob, n_classes=None):
  147. if n_classes is None:
  148. n_classes = len(np.unique(y_true))
  149. if n_classes > 2:
  150. auc = roc_auc_score(y_true, prob, multi_class='ovr')
  151. elif n_classes == 2:
  152. auc = roc_auc_score(y_true, prob[:, 1])
  153. else:
  154. raise ValueError
  155. return auc
  156. def reref(data, method):
  157. data = data.copy()
  158. if method == 'average':
  159. data -= data.mean(axis=0)
  160. return data
  161. elif method == 'bipolar':
  162. # neo specific
  163. anode = data[[0, 1, 2, 3, 7, 6, 5]]
  164. cathode = data[[1, 2, 3, 7, 6, 5, 4]]
  165. return anode - cathode
  166. elif method == 'monopolar':
  167. return data
  168. else:
  169. raise ValueError(f'Rereference method unacceptable, got {str(method)}, expect "monopolar" or "average" or "bipolar"')