Browse Source

Algorithm improvement: change the hyperparameter-tuning objective to selecting the configuration with the highest decoding accuracy, to avoid cases of high AUC but low accuracy. (Note: the diff implements this by selecting on macro-averaged F1 rather than raw accuracy.)

dk 8 months ago
parent
commit
79ede5a6b4
3 changed files with 11 additions and 13 deletions
  1. 9 11
      backend/bci_core/utils.py
  2. 1 1
      backend/training.py
  3. 1 1
      backend/validation.py

+ 9 - 11
backend/bci_core/utils.py

@@ -3,7 +3,7 @@ import itertools
 from datetime import datetime
 import joblib
 from sklearn.model_selection import KFold
-from sklearn.metrics import roc_auc_score
+from sklearn.metrics import roc_auc_score, f1_score
 from mne import baseline
 import logging
 import os
@@ -166,31 +166,29 @@ def param_search(model_func, X, y, params: dict, random_state=123):
     """
     kfold = KFold(n_splits=10, shuffle=True, random_state=random_state)
 
-    best_auc = -1
+    best_metric = -1
     best_param = None
     for p_dict in product_dict(**params):
         model = model_func(**p_dict)
 
-        n_classes = len(np.unique(y))
-
-        y_pred = np.zeros((len(y), n_classes))
+        y_pred = np.zeros((len(y),))
 
         for train_idx, test_idx in kfold.split(X):
             X_train, y_train = X[train_idx], y[train_idx]
             X_test = X[test_idx]
             model.fit(X_train, y_train)
-            y_pred[test_idx] = model.predict_proba(X_test)
-        auc = multiclass_auc_score(y, y_pred, n_classes=n_classes)
+            y_pred[test_idx] = model.predict(X_test)
+        f1 = f1_score(y, y_pred, average='macro')
 
         # update
-        if auc > best_auc:
+        if f1 > best_metric:
             best_param = p_dict
-            best_auc = auc
+            best_metric = f1
         
         # print each steps
-        logger.debug(f'Current: {p_dict}, {auc}; Best: {best_param}, {best_auc}')
+        logger.debug(f'Current: {p_dict}, {f1}; Best: {best_param}, {best_metric}')
 
-    return best_auc, best_param
+    return best_metric, best_param
 
 
 def multiclass_auc_score(y_true, prob, n_classes=None):

+ 1 - 1
backend/training.py

@@ -73,7 +73,7 @@ def _param_search(model, raw, duration, events):
 
     best_auc, best_param = bci_utils.param_search(LogisticRegression, X_embed, y, param)
 
-    logging.info(f'Best parameter: {best_param}, best auc {best_auc}')
+    logging.info(f'Best parameter: {best_param}, best f1 {best_auc}')
 
     # train and dump best model
     model_for_train = LogisticRegression(**best_param)

+ 1 - 1
backend/validation.py

@@ -59,7 +59,7 @@ def val_by_epochs(raw, model_path, event_id, trial_duration=1., ):
     y = events[:, -1]
     auc = bci_utils.multiclass_auc_score(y, prob)
     accu = accuracy_score(y, y_pred)
-    f1 = f1_score(y, y_pred, pos_label=np.max(y), average='macro')
+    f1 = f1_score(y, y_pred, average='macro')
     # confusion matrix
     fig_conf = bci_viz.plot_confusion_matrix(y, y_pred)
     return (auc, accu, f1), fig_conf