ソースを参照

Feat: support color settings

dk 1 年間 前
コミット
cf1d8eb251
1 ファイル変更8 行追加4 行削除
  1. 8 4
      backend/bci_core/viz.py

+ 8 - 4
backend/bci_core/viz.py

@@ -120,20 +120,24 @@ def plot_confusion_matrix(y_true, y_pred):
     return disp.figure_
 
 
-def plot_states(time_range, pred_states, ax):
+def plot_states(time_range, pred_states, ax, colors=None):
     classes = np.unique(pred_states)
+    if colors is None:
+        colors = [plt.get_cmap('tab10')(i)[:3] for i in range(len(classes))]
     for i, c in enumerate(classes):
         ax.fill_between(np.linspace(*time_range, len(pred_states)), 0, 1,
-                        where=(pred_states == c), alpha=0.6, color=plt.get_cmap('tab10')(i)[:3])
+                        where=(pred_states == c), alpha=0.6, color=colors[i])
     return ax
 
 
-def plot_state_prob_with_cue(time_range, true_states, pred_probs, ax):
+def plot_state_prob_with_cue(time_range, true_states, pred_probs, ax, colors=None):
     # normalize
     ax.plot(np.linspace(*time_range, len(pred_probs)), pred_probs, color='k')
     # for each class, fill different colors
     classes = np.unique(true_states)
+    if colors is None:
+        colors = [plt.get_cmap('tab10')(i)[:3] for i in range(len(classes))]
     for i, c in enumerate(classes):
         ax.fill_between(np.linspace(*time_range, len(true_states)), 0, 1,
-                        where=(true_states == c), alpha=0.6, color=plt.get_cmap('tab10')(i)[:3])
+                        where=(true_states == c), alpha=0.6, color=colors[i])
     return ax