|
@@ -3,15 +3,18 @@
|
|
|
测试AUC,
|
|
|
绘制Confusion matrix, ERSD map
|
|
|
'''
|
|
|
-import numpy as np
|
|
|
-import joblib
|
|
|
-import mne
|
|
|
-import yaml
|
|
|
import os
|
|
|
import argparse
|
|
|
import logging
|
|
|
+
|
|
|
+import mne
|
|
|
+import yaml
|
|
|
+import joblib
|
|
|
+import numpy as np
|
|
|
from scipy import signal
|
|
|
from sklearn.metrics import accuracy_score, f1_score
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+
|
|
|
from dataloaders import neo
|
|
|
import bci_core.utils as bci_utils
|
|
|
import bci_core.viz as bci_viz
|
|
@@ -60,7 +63,7 @@ def val_by_epochs(raw, model_path, event_id, trial_duration=1., ):
|
|
|
y = events[:, -1]
|
|
|
auc = bci_utils.multiclass_auc_score(y, prob)
|
|
|
accu = accuracy_score(y, y_pred)
|
|
|
- f1 = f1_score(y, y_pred, pos_label=np.max(y))
|
|
|
+ f1 = f1_score(y, y_pred, pos_label=np.max(y), average='macro')
|
|
|
# confusion matrix
|
|
|
fig_conf = bci_viz.plot_confusion_matrix(y, y_pred)
|
|
|
return (auc, accu, f1), fig_conf
|
|
@@ -130,3 +133,5 @@ if __name__ == '__main__':
|
|
|
|
|
|
fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
|
|
|
fig_conf.savefig(os.path.join(data_dir, 'confusion_matrix.pdf'))
|
|
|
+
|
|
|
+ plt.show()
|