Skip to content

Commit 6bfc8c9

Browse files
committed
fix mc_confusion_matrix plot
1 parent a82b98d commit 6bfc8c9

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

bayesflow/diagnostics/plots/mc_confusion_matrix.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def mc_confusion_matrix(
2323
tick_fontsize: int = 12,
2424
xtick_rotation: int = None,
2525
ytick_rotation: int = None,
26-
normalize: bool = True,
26+
normalize: str = None,
2727
cmap: matplotlib.colors.Colormap | str = None,
2828
title: bool = True,
2929
) -> plt.Figure:
@@ -51,9 +51,11 @@ def mc_confusion_matrix(
5151
Rotation of x-axis tick labels (helps with long model names).
5252
ytick_rotation: int, optional, default: None
5353
Rotation of y-axis tick labels (helps with long model names).
54-
normalize : bool, optional, default: True
55-
A flag for normalization of the confusion matrix.
56-
If True, each row of the confusion matrix is normalized to sum to 1.
54+
normalize : {'true', 'pred', 'all'}, default=None
55+
Passed to sklearn.metrics.confusion_matrix.
56+
Normalizes confusion matrix over the true (rows), predicted (columns)
57+
conditions or all the population. If None, confusion matrix will not be
58+
normalized.
5759
cmap : matplotlib.colors.Colormap or str, optional, default: None
5860
Colormap to be used for the cells. If a str, it should be the name of a registered colormap,
5961
e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red.
@@ -77,29 +79,26 @@ def mc_confusion_matrix(
7779
pred_models = ops.argmax(pred_models, axis=1)
7880

7981
# Compute confusion matrix
80-
cm = confusion_matrix(true_models, pred_models)
81-
82-
# if normalize:
83-
# # Sum along rows and keep dimensions for broadcasting
84-
# cm_sum = ops.sum(cm, axis=1, keepdims=True)
85-
#
86-
# # Broadcast division for normalization
87-
# cm_normalized = cm / cm_sum
82+
cm = confusion_matrix(true_models, pred_models, normalize=normalize)
8883

8984
# Initialize figure
9085
fig, ax = make_figure(1, 1, figsize=fig_size)
86+
ax = ax[0]
9187
im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
9288
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.75)
9389

9490
cbar.ax.tick_params(labelsize=value_fontsize)
9591

96-
ax.set(xticks=ops.arange(cm.shape[1]), yticks=ops.arange(cm.shape[0]))
92+
ax.set_xticks(range(cm.shape[0]))
9793
ax.set_xticklabels(model_names, fontsize=tick_fontsize)
9894
if xtick_rotation:
9995
plt.xticks(rotation=xtick_rotation, ha="right")
96+
97+
ax.set_yticks(range(cm.shape[1]))
10098
ax.set_yticklabels(model_names, fontsize=tick_fontsize)
10199
if ytick_rotation:
102100
plt.yticks(rotation=ytick_rotation)
101+
103102
ax.set_xlabel("Predicted model", fontsize=label_fontsize)
104103
ax.set_ylabel("True model", fontsize=label_fontsize)
105104

0 commit comments

Comments
 (0)