Skip to content

Commit cb785b5

Browse files
Kucharssimhan-ol
authored andcommitted
fix mc_calibration plot (#294)
1 parent 76041c2 commit cb785b5

File tree

4 files changed

+82
-47
lines changed

4 files changed

+82
-47
lines changed

bayesflow/diagnostics/plots/mc_calibration.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,41 +68,42 @@ def mc_calibration(
6868

6969
# Gather plot data and metadata into a dictionary
7070
plot_data = prepare_plot_data(
71-
estimates=pred_models,
72-
ground_truths=true_models,
71+
targets=pred_models,
72+
references=true_models,
7373
variable_names=model_names,
7474
num_col=num_col,
7575
num_row=num_row,
7676
figsize=figsize,
77+
default_name="M",
7778
)
7879

7980
# Compute calibration
8081
cal_errors, true_probs, pred_probs = expected_calibration_error(
81-
plot_data["ground_truths"], plot_data["estimates"], num_bins
82+
plot_data["references"], plot_data["targets"], num_bins
8283
)
8384

8485
for j, ax in enumerate(plot_data["axes"].flat):
8586
# Plot calibration curve
86-
ax[j].plot(pred_probs[j], true_probs[j], "o-", color=color)
87+
ax.plot(pred_probs[j], true_probs[j], "o-", color=color)
8788

8889
# Plot PMP distribution over bins
8990
uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
90-
norm_weights = np.ones_like(plot_data["estimates"]) / len(plot_data["estimates"])
91-
ax[j].hist(plot_data["estimates"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)
91+
norm_weights = np.ones_like(plot_data["targets"]) / len(plot_data["targets"])
92+
ax.hist(plot_data["targets"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)
9293

9394
# Plot AB line
94-
ax[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9)
95+
ax.plot((0, 1), (0, 1), "--", color="black", alpha=0.9)
9596

9697
# Tweak plot
97-
ax[j].set_xlim([0 - epsilon, 1 + epsilon])
98-
ax[j].set_ylim([0 - epsilon, 1 + epsilon])
99-
ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
100-
ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
98+
ax.set_xlim([0 - epsilon, 1 + epsilon])
99+
ax.set_ylim([0 - epsilon, 1 + epsilon])
100+
ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
101+
ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
101102

102103
# Add ECE label
103104
add_metric(
104-
ax[j],
105-
metric_text=r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}",
105+
ax,
106+
metric_text=r"$\widehat{{\mathrm{{ECE}}}}$",
106107
metric_value=cal_errors[j],
107108
metric_fontsize=metric_fontsize,
108109
)

bayesflow/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636
from .optimal_transport import optimal_transport
3737
from .plot_utils import (
38-
check_posterior_prior_shapes,
38+
check_estimates_prior_shapes,
3939
prepare_plot_data,
4040
add_titles_and_labels,
4141
prettify_subplots,

bayesflow/utils/plot_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import matplotlib.pyplot as plt
55
import seaborn as sns
66

7-
from .validators import check_posterior_prior_shapes
7+
from .validators import check_estimates_prior_shapes
88
from .dict_utils import dicts_to_arrays
99

1010

@@ -52,7 +52,7 @@ def prepare_plot_data(
5252
plot_data = dicts_to_arrays(
5353
targets=targets, references=references, variable_names=variable_names, default_name=default_name
5454
)
55-
check_posterior_prior_shapes(plot_data["targets"], plot_data["references"])
55+
check_estimates_prior_shapes(plot_data["targets"], plot_data["references"])
5656

5757
# Configure layout
5858
num_row, num_col = set_layout(plot_data["num_variables"], num_row, num_col, stacked)

bayesflow/utils/validators.py

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,48 +8,82 @@ def check_lengths_same(*args):
88
raise ValueError(f"All tuple arguments must have the same length, but lengths are {tuple(map(len, args))}.")
99

1010

11-
def check_posterior_prior_shapes(post_variables: Tensor, prior_variables: Tensor):
11+
def check_prior_shapes(variables: Tensor):
1212
"""
13-
Checks requirements for the shapes of posterior and prior draws as
14-
necessitated by most diagnostic functions.
13+
Checks the shape of posterior draws as required by most diagnostic functions
1514
1615
Parameters
1716
----------
18-
post_samples : Tensor of shape (num_data_sets, num_post_draws, num_params)
19-
The posterior draws obtained from num_data_sets
20-
prior_samples : Tensor of shape (num_data_sets, num_params)
21-
The prior draws obtained for generating num_data_sets
22-
23-
Raises
24-
------
25-
ShapeError
26-
If there is a deviation form the expected shapes of `post_samples` and `prior_samples`.
17+
variables : Tensor of shape (num_data_sets, num_params)
18+
The prior_samples from generating num_data_sets
2719
"""
2820

29-
if len(post_variables.shape) != 3:
21+
if len(variables.shape) != 2:
3022
raise ShapeError(
31-
"post_samples should be a 3-dimensional array, with the "
32-
"first dimension being the number of (simulated) data sets, "
33-
"the second dimension being the number of posterior draws per data set, "
34-
"and the third dimension being the number of parameters (marginal distributions), "
35-
f"but your input has dimensions {len(post_variables.shape)}"
23+
"prior_samples samples should be a 2-dimensional array, with the "
24+
"first dimension being the number of (simulated) data sets / prior_samples draws "
25+
"and the second dimension being the number of variables, "
26+
f"but your input has dimensions {len(variables.shape)}"
3627
)
37-
elif len(prior_variables.shape) != 2:
28+
29+
30+
def check_estimates_shapes(variables: Tensor):
31+
"""
32+
Checks the shape of model-generated predictions (posterior draws, point estimates)
33+
as required by most diagnostic functions
34+
35+
Parameters
36+
----------
37+
variables : Tensor of shape (num_data_sets, num_post_draws, num_params)
38+
The prior_samples from generating num_data_sets
39+
"""
40+
if len(variables.shape) != 2 and len(variables.shape) != 3:
3841
raise ShapeError(
39-
"prior_samples should be a 2-dimensional array, with the "
40-
"first dimension being the number of (simulated) data sets / prior draws "
41-
"and the second dimension being the number of parameters (marginal distributions), "
42-
f"but your input has dimensions {len(prior_variables.shape)}"
42+
"estimates should be a 2- or 3-dimensional array, with the "
43+
"first dimension being the number of data sets, "
44+
"(optional) second dimension the number of posterior draws per data set, "
45+
"and the last dimension the number of estimated variables, "
46+
f"but your input has dimensions {len(variables.shape)}"
4347
)
44-
elif post_variables.shape[0] != prior_variables.shape[0]:
48+
49+
50+
def check_consistent_shapes(estimates: Tensor, prior_samples: Tensor):
51+
"""
52+
Checks whether the model-generated predictions (posterior draws, point estimates) and
53+
prior_samples have consistent leading (num_data_sets) and trailing (num_params) dimensions
54+
"""
55+
if estimates.shape[0] != prior_samples.shape[0]:
4556
raise ShapeError(
46-
"The number of elements over the first dimension of post_samples and prior_samples"
47-
f"should match, but post_samples has {post_variables.shape[0]} and prior_samples has "
48-
f"{prior_variables.shape[0]} elements, respectively."
57+
"The number of elements over the first dimension of estimates and prior_samples"
58+
f"should match, but estimates have {estimates.shape[0]} and prior_samples has "
59+
f"{prior_samples.shape[0]} elements, respectively."
4960
)
50-
elif post_variables.shape[-1] != prior_variables.shape[-1]:
61+
if estimates.shape[-1] != prior_samples.shape[-1]:
5162
raise ShapeError(
52-
"The number of elements over the last dimension of post_samples and prior_samples"
53-
f"should match, but post_samples has {post_variables.shape[1]} and prior_samples has "
54-
f"{prior_variables.shape[-1]} elements, respectively."
63+
"The number of elements over the last dimension of estimates and prior_samples"
64+
f"should match, but estimates has {estimates.shape[0]} and prior_samples has "
65+
f"{prior_samples.shape[0]} elements, respectively."
5566
)
67+
68+
69+
def check_estimates_prior_shapes(estimates: Tensor, prior_samples: Tensor):
70+
"""
71+
Checks requirements for the shapes of estimates and prior_samples draws as
72+
necessitated by most diagnostic functions.
73+
74+
Parameters
75+
----------
76+
estimates : Tensor of shape (num_data_sets, num_post_draws, num_params) or (num_data_sets, num_params)
77+
The model-generated predictions (posterior draws, point estimates) obtained from num_data_sets
78+
prior_samples : Tensor of shape (num_data_sets, num_params)
79+
The prior_samples draws obtained for generating num_data_sets
80+
81+
Raises
82+
------
83+
ShapeError
84+
If there is a deviation form the expected shapes of `estimates` and `estimates`.
85+
"""
86+
87+
check_estimates_shapes(estimates)
88+
check_prior_shapes(prior_samples)
89+
check_consistent_shapes(estimates, prior_samples)

0 commit comments

Comments
 (0)