# utils.py
import itertools
import logging
import os

import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold

logger = logging.getLogger(__name__)


def parse_model_type(model_path):
    """Parse the model type and event names from a model file name.

    The file name is expected to look like ``<model_type>_<event1+event2+...>_<suffix>``.
    """
    model_path = os.path.normpath(model_path)
    file_name = os.path.basename(model_path)
    model_type, events, _ = file_name.split('_')
    events = events.split('+')
    return model_type.lower(), events
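
# Illustrative usage (the exact file-name layout is an assumption based on the parsing above):
#   >>> parse_model_type('checkpoints/EEGNet_left+right_final.pth')
#   ('eegnet', ['left', 'right'])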


def event_metric(event_true, event_pred, fs, hit_time_range=(0, 3), ignore_event=(0,), f_beta=1.):
    """Single-trial event detection score (precision, recall and F_beta).

    Args:
        event_true (ndarray): true events, an (n, 3) array whose first column is the onset
            sample and whose third column is the event code
        event_pred (ndarray): predicted events in the same layout
        fs (float): sampling rate in Hz
        hit_time_range (tuple): window in seconds, relative to the true onset, within which
            a prediction counts as a hit
        ignore_event (tuple): event codes to ignore
        f_beta (float): beta of the F_beta score
    Return:
        precision (float), recall (float), F_beta score (float)
    """
    event_true = event_true.copy()[np.logical_not(np.isin(event_true[:, 2], ignore_event))]
    event_pred = event_pred.copy()[np.logical_not(np.isin(event_pred[:, 2], ignore_event))]
    true_idx = 0
    pred_idx = 0
    correct_count = 0
    # convert the hit window from seconds to samples
    hit_time_range = (int(fs * hit_time_range[0]), int(fs * hit_time_range[1]))
    # two-pointer sweep over the time-ordered true and predicted events
    while true_idx < len(event_true) and pred_idx < len(event_pred):
        if event_true[true_idx, 0] + hit_time_range[0] <= event_pred[pred_idx, 0] < event_true[true_idx, 0] + hit_time_range[1]:
            if event_true[true_idx, 2] == event_pred[pred_idx, 2]:
                # hit: right time window and right event code
                correct_count += 1
                true_idx += 1
                pred_idx += 1
            else:
                pred_idx += 1
        elif event_pred[pred_idx, 0] < event_true[true_idx, 0] + hit_time_range[0]:
            # prediction is earlier than the current hit window, advance the prediction pointer
            pred_idx += 1
        else:
            # prediction is later than the current hit window, the true event is missed
            true_idx += 1
    # guard against empty event lists to avoid division by zero
    precision = correct_count / len(event_pred) if len(event_pred) else 0.
    recall = correct_count / len(event_true) if len(event_true) else 0.
    if precision + recall == 0:
        return precision, recall, 0.
    fbeta_score = (1 + f_beta ** 2) * (precision * recall) / (f_beta ** 2 * precision + recall)
    return precision, recall, fbeta_score
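
# Illustrative usage (hypothetical events; rows are [onset_sample, unused, event_code]):
#   >>> true = np.array([[100, 0, 1], [300, 0, 2]])
#   >>> pred = np.array([[110, 0, 1], [305, 0, 2]])
#   >>> event_metric(true, pred, fs=100)
#   (1.0, 1.0, 1.0)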


def cut_epochs(t, data, timestamps):
    """
    Cut continuous data into epochs.
    :param t: tuple (start, end, samplerate), start/end in seconds relative to each timestamp
    :param data: ndarray (..., n_times), the last dimension must be time
    :param timestamps: list of event timestamps (in samples)
    :return: ndarray (n_epochs, ..., n_times), the first dimension is the epochs
    """
    timestamps = np.array(timestamps)
    start = timestamps + int(t[0] * t[2])
    end = timestamps + int(t[1] * t[2])
    # boundary check: drop the first/last epoch if it falls outside the recording
    if start[0] < 0:
        start = start[1:]
        end = end[1:]
    if end[-1] > data.shape[-1]:
        start = start[:-1]
        end = end[:-1]
    epochs = np.stack([data[..., s:e] for s, e in zip(start, end)], axis=0)
    return epochs
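
# Illustrative usage: 1 s epochs from 10-channel data sampled at 100 Hz (random data for the sketch):
#   >>> data = np.random.randn(10, 1000)
#   >>> cut_epochs((0., 1., 100), data, [100, 300, 500]).shape
#   (3, 10, 100)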


def product_dict(**kwargs):
    """Yield every combination of the given keyword lists as a dict (Cartesian product)."""
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))
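
# Illustrative usage:
#   >>> list(product_dict(C=[0.1, 1], kernel=['linear', 'rbf']))
#   [{'C': 0.1, 'kernel': 'linear'}, {'C': 0.1, 'kernel': 'rbf'},
#    {'C': 1, 'kernel': 'linear'}, {'C': 1, 'kernel': 'rbf'}]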


def param_search(model_func, X, y, params: dict, random_state=123):
    """
    Exhaustive grid search scored by 10-fold cross-validated AUC.
    :param model_func: model builder, called as model_func(**param_dict); must return an
        estimator with fit/predict_proba
    :param X: ndarray (n_trials, n_channels, n_times)
    :param y: ndarray (n_trials, )
    :param params: dict of params, key is param name and value is the search range (iterable)
    :param random_state: random seed for the k-fold shuffling
    :return: tuple (best_auc, best_param)
    """
    kfold = KFold(n_splits=10, shuffle=True, random_state=random_state)
    n_classes = len(np.unique(y))
    best_auc = -1
    best_param = None
    for p_dict in product_dict(**params):
        model = model_func(**p_dict)
        # out-of-fold predicted probabilities for every trial
        y_pred = np.zeros((len(y), n_classes))
        for train_idx, test_idx in kfold.split(X):
            X_train, y_train = X[train_idx], y[train_idx]
            X_test = X[test_idx]
            model.fit(X_train, y_train)
            y_pred[test_idx] = model.predict_proba(X_test)
        auc = multiclass_auc_score(y, y_pred, n_classes=n_classes)
        # keep the best parameter set so far
        if auc > best_auc:
            best_param = p_dict
            best_auc = auc
        # log each step
        logger.debug(f'Current: {p_dict}, {auc}; Best: {best_param}, {best_auc}')
    return best_auc, best_param
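
# Illustrative usage (build_model is a hypothetical factory returning an estimator with
# fit/predict_proba; X, y shaped as described in the docstring):
#   >>> best_auc, best_param = param_search(build_model, X, y,
#   ...                                     {'n_components': [2, 4, 6], 'C': [0.1, 1.0]})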


def multiclass_auc_score(y_true, prob, n_classes=None):
    """ROC AUC for binary or multi-class problems; `prob` holds per-class probabilities."""
    if n_classes is None:
        n_classes = len(np.unique(y_true))
    if n_classes > 2:
        # one-vs-rest AUC averaged over classes
        auc = roc_auc_score(y_true, prob, multi_class='ovr')
    elif n_classes == 2:
        # binary case: score against the positive-class probability
        auc = roc_auc_score(y_true, prob[:, 1])
    else:
        raise ValueError('y_true must contain at least two classes')
    return auc
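
# Illustrative usage (binary case):
#   >>> y_true = np.array([0, 1, 1, 0])
#   >>> prob = np.array([[0.8, 0.2], [0.3, 0.7], [0.1, 0.9], [0.6, 0.4]])
#   >>> multiclass_auc_score(y_true, prob)
#   1.0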