utils.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import numpy as np
  2. import itertools
  3. from sklearn.model_selection import KFold
  4. from sklearn.metrics import roc_auc_score
  5. import logging
  6. import os
  7. def parse_model_type(model_path):
  8. model_path = os.path.normpath(model_path)
  9. file_name = model_path.split(os.sep)[-1]
  10. model_type, events, _ = file_name.split('_')
  11. events = events.split('+')
  12. return model_type.lower(), events
  13. def event_metric(event_true, event_pred, fs, hit_time_range=(0, 3), ignore_event=(0,), f_beta=1.):
  14. """评价单试次f_alpha score
  15. Args:
  16. event_true:
  17. event_pred:
  18. fs:
  19. hit_time_range (tuple):
  20. ignore_event (tuple): ignore certain events
  21. f_beta (float): f_(alpha) score
  22. Return:
  23. f_beta score (float): f_alpha score
  24. """
  25. event_true = event_true.copy()[np.logical_not(np.isin(event_true[:, 2], ignore_event))]
  26. event_pred = event_pred.copy()[np.logical_not(np.isin(event_pred[:, 2], ignore_event))]
  27. true_idx = 0
  28. pred_idx = 0
  29. correct_count = 0
  30. hit_time_range = (int(fs * hit_time_range[0]), int(fs * hit_time_range[1]))
  31. while true_idx < len(event_true) and pred_idx < len(event_pred):
  32. if event_true[true_idx, 0] + hit_time_range[0] <= event_pred[pred_idx, 0] < event_true[true_idx, 0] + hit_time_range[1]:
  33. if event_true[true_idx, 2] == event_pred[pred_idx, 2]:
  34. correct_count += 1
  35. true_idx += 1
  36. pred_idx += 1
  37. else:
  38. pred_idx += 1
  39. elif event_pred[pred_idx, 0] < event_true[true_idx, 0] + hit_time_range[0]:
  40. pred_idx += 1
  41. else:
  42. true_idx += 1
  43. precision = correct_count / len(event_pred)
  44. recall = correct_count / len(event_true)
  45. fbeta_score = (1 + f_beta ** 2) * (precision * recall) / (f_beta ** 2 * precision + recall)
  46. return precision, recall, fbeta_score
  47. def cut_epochs(t, data, timestamps):
  48. """
  49. cutting raw data into epochs
  50. :param t: tuple (start, end, samplerate)
  51. :param data: ndarray (..., n_times), the last dimension should be the times
  52. :param timestamps: list of timestamps
  53. :return: ndarray (n_epochs, ... , n_times), the first dimension be the epochs
  54. """
  55. timestamps = np.array(timestamps)
  56. start = timestamps + int(t[0] * t[2])
  57. end = timestamps + int(t[1] * t[2])
  58. # do boundary check
  59. if start[0] < 0:
  60. start = start[1:]
  61. end = end[1:]
  62. if end[-1] > data.shape[-1]:
  63. start = start[:-1]
  64. end = end[:-1]
  65. epochs = np.stack([data[..., s:e] for s, e in zip(start, end)], axis=0)
  66. return epochs
  67. def product_dict(**kwargs):
  68. keys = kwargs.keys()
  69. vals = kwargs.values()
  70. for instance in itertools.product(*vals):
  71. yield dict(zip(keys, instance))
  72. def param_search(model_func, X, y, params: dict, random_state=123):
  73. """
  74. :param model_func: model builder
  75. :param X: ndarray (n_trials, n_channels, n_times)
  76. :param y: ndarray (n_trials, )
  77. :param params: dict of params, key is param name and value is search range
  78. :param random_state:
  79. :return:
  80. """
  81. kfold = KFold(n_splits=10, shuffle=True, random_state=random_state)
  82. best_auc = -1
  83. best_param = None
  84. for p_dict in product_dict(**params):
  85. model = model_func(**p_dict)
  86. n_classes = len(np.unique(y))
  87. y_pred = np.zeros((len(y), n_classes))
  88. for train_idx, test_idx in kfold.split(X):
  89. X_train, y_train = X[train_idx], y[train_idx]
  90. X_test = X[test_idx]
  91. model.fit(X_train, y_train)
  92. y_pred[test_idx] = model.predict_proba(X_test)
  93. if n_classes == 3:
  94. auc = roc_auc_score(y, y_pred, multi_class='ovr')
  95. elif n_classes == 2:
  96. auc = roc_auc_score(y, y_pred[:, 1])
  97. else:
  98. raise ValueError
  99. # update
  100. if auc > best_auc:
  101. best_param = p_dict
  102. best_auc = auc
  103. # print each steps
  104. logging.debug(f'Current: {p_dict}, {auc}; Best: {best_param}, {best_auc}')
  105. return best_auc, best_param