Browse Source

ersd可视化方法优化

dk 1 year ago
parent
commit
8184fcb65b
2 changed files with 12 additions and 7 deletions
  1. 2 2
      backend/bci_core/viz.py
  2. 10 5
      backend/validation.py

+ 2 - 2
backend/bci_core/viz.py

@@ -1,5 +1,5 @@
 import mne
-from mne.time_frequency import psd_array_welch
+from mne.time_frequency import psd_array_multitaper
 import matplotlib.pyplot as plt
 import matplotlib as mpl
 import numpy as np
@@ -82,7 +82,7 @@ def plot_ersd(data, events, sfreq, epoch_time_range, event_id, rest_event=0):
     event_desc = {v:k for k, v in event_id.items()}
     epochs = cut_epochs((*epoch_time_range, sfreq), data, events[:, 0])
 
-    psd, freqs = psd_array_welch(epochs, sfreq, fmin=0, fmax=200)
+    psd, freqs = psd_array_multitaper(epochs, sfreq, fmin=0, fmax=200, bandwidth=15)
 
     mean_psd_rest = psd[events[:, -1] == rest_event].mean(axis=0)
     ersds = []

+ 10 - 5
backend/validation.py

@@ -3,15 +3,18 @@
 测试AUC,
 绘制Confusion matrix, ERSD map
 '''
-import numpy as np
-import joblib
-import mne
-import yaml
 import os
 import argparse
 import logging
+
+import mne
+import yaml
+import joblib
+import numpy as np
 from scipy import signal
 from sklearn.metrics import accuracy_score, f1_score
+import matplotlib.pyplot as plt
+
 from dataloaders import neo
 import bci_core.utils as bci_utils
 import bci_core.viz as bci_viz
@@ -60,7 +63,7 @@ def val_by_epochs(raw, model_path, event_id, trial_duration=1., ):
     y = events[:, -1]
     auc = bci_utils.multiclass_auc_score(y, prob)
     accu = accuracy_score(y, y_pred)
-    f1 = f1_score(y, y_pred, pos_label=np.max(y))
+    f1 = f1_score(y, y_pred, pos_label=np.max(y), average='macro')
     # confusion matrix
     fig_conf = bci_viz.plot_confusion_matrix(y, y_pred)
     return (auc, accu, f1), fig_conf
@@ -130,3 +133,5 @@ if __name__ == '__main__':
     
     fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
     fig_conf.savefig(os.path.join(data_dir, 'confusion_matrix.pdf'))   
+
+    plt.show()