|
@@ -120,20 +120,24 @@ def plot_confusion_matrix(y_true, y_pred):
|
|
return disp.figure_
|
|
return disp.figure_
|
|
|
|
|
|
|
|
|
|
-def plot_states(time_range, pred_states, ax):
|
|
|
|
|
|
+def plot_states(time_range, pred_states, ax, colors=None):
|
|
classes = np.unique(pred_states)
|
|
classes = np.unique(pred_states)
|
|
|
|
+ if colors is None:
|
|
|
|
+ colors = [plt.get_cmap('tab10')(i)[:3] for i in range(len(classes))]
|
|
for i, c in enumerate(classes):
|
|
for i, c in enumerate(classes):
|
|
ax.fill_between(np.linspace(*time_range, len(pred_states)), 0, 1,
|
|
ax.fill_between(np.linspace(*time_range, len(pred_states)), 0, 1,
|
|
- where=(pred_states == c), alpha=0.6, color=plt.get_cmap('tab10')(i)[:3])
|
|
|
|
|
|
+ where=(pred_states == c), alpha=0.6, color=colors[i])
|
|
return ax
|
|
return ax
|
|
|
|
|
|
|
|
|
|
-def plot_state_prob_with_cue(time_range, true_states, pred_probs, ax):
|
|
|
|
|
|
+def plot_state_prob_with_cue(time_range, true_states, pred_probs, ax, colors=None):
|
|
# normalize
|
|
# normalize
|
|
ax.plot(np.linspace(*time_range, len(pred_probs)), pred_probs, color='k')
|
|
ax.plot(np.linspace(*time_range, len(pred_probs)), pred_probs, color='k')
|
|
# for each class, fill different colors
|
|
# for each class, fill different colors
|
|
classes = np.unique(true_states)
|
|
classes = np.unique(true_states)
|
|
|
|
+ if colors is None:
|
|
|
|
+ colors = [plt.get_cmap('tab10')(i)[:3] for i in range(len(classes))]
|
|
for i, c in enumerate(classes):
|
|
for i, c in enumerate(classes):
|
|
ax.fill_between(np.linspace(*time_range, len(true_states)), 0, 1,
|
|
ax.fill_between(np.linspace(*time_range, len(true_states)), 0, 1,
|
|
- where=(true_states == c), alpha=0.6, color=plt.get_cmap('tab10')(i)[:3])
|
|
|
|
|
|
+ where=(true_states == c), alpha=0.6, color=colors[i])
|
|
return ax
|
|
return ax
|