@@ -23,7 +23,7 @@ def mc_confusion_matrix(
23
23
tick_fontsize : int = 12 ,
24
24
xtick_rotation : int = None ,
25
25
ytick_rotation : int = None ,
26
- normalize : bool = True ,
26
+ normalize : str = None ,
27
27
cmap : matplotlib .colors .Colormap | str = None ,
28
28
title : bool = True ,
29
29
) -> plt .Figure :
@@ -51,9 +51,11 @@ def mc_confusion_matrix(
51
51
Rotation of x-axis tick labels (helps with long model names).
52
52
ytick_rotation: int, optional, default: None
53
53
Rotation of y-axis tick labels (helps with long model names).
54
- normalize : bool, optional, default: True
55
- A flag for normalization of the confusion matrix.
56
- If True, each row of the confusion matrix is normalized to sum to 1.
54
+ normalize : {'true', 'pred', 'all'}, default=None
55
+ Passed to sklearn.metrics.confusion_matrix.
56
+ Normalizes confusion matrix over the true (rows), predicted (columns)
57
+ conditions or all the population. If None, confusion matrix will not be
58
+ normalized.
57
59
cmap : matplotlib.colors.Colormap or str, optional, default: None
58
60
Colormap to be used for the cells. If a str, it should be the name of a registered colormap,
59
61
e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red.
@@ -77,29 +79,26 @@ def mc_confusion_matrix(
77
79
pred_models = ops .argmax (pred_models , axis = 1 )
78
80
79
81
# Compute confusion matrix
80
- cm = confusion_matrix (true_models , pred_models )
81
-
82
- # if normalize:
83
- # # Sum along rows and keep dimensions for broadcasting
84
- # cm_sum = ops.sum(cm, axis=1, keepdims=True)
85
- #
86
- # # Broadcast division for normalization
87
- # cm_normalized = cm / cm_sum
82
+ cm = confusion_matrix (true_models , pred_models , normalize = normalize )
88
83
89
84
# Initialize figure
90
85
fig , ax = make_figure (1 , 1 , figsize = fig_size )
86
+ ax = ax [0 ]
91
87
im = ax .imshow (cm , interpolation = "nearest" , cmap = cmap )
92
88
cbar = ax .figure .colorbar (im , ax = ax , shrink = 0.75 )
93
89
94
90
cbar .ax .tick_params (labelsize = value_fontsize )
95
91
96
- ax .set ( xticks = ops . arange ( cm . shape [ 1 ]), yticks = ops . arange (cm .shape [0 ]))
92
+ ax .set_xticks ( range (cm .shape [0 ]))
97
93
ax .set_xticklabels (model_names , fontsize = tick_fontsize )
98
94
if xtick_rotation :
99
95
plt .xticks (rotation = xtick_rotation , ha = "right" )
96
+
97
+ ax .set_yticks (range (cm .shape [1 ]))
100
98
ax .set_yticklabels (model_names , fontsize = tick_fontsize )
101
99
if ytick_rotation :
102
100
plt .yticks (rotation = ytick_rotation )
101
+
103
102
ax .set_xlabel ("Predicted model" , fontsize = label_fontsize )
104
103
ax .set_ylabel ("True model" , fontsize = label_fontsize )
105
104
0 commit comments