
Commit 1a85fc8

Merge branch 'dev' into point-estimation

2 parents ce07855 + cb785b5

39 files changed: +1187 −818 lines

bayesflow/__init__.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -36,6 +36,10 @@ def setup():

         torch.autograd.set_grad_enabled(False)

+    from bayesflow.utils import logging
+
+    logging.info(f"Using backend {keras.backend.backend()!r}")
+

 # call and clean up namespace
 setup()
```
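
The backend reported by the new log line is whatever Keras 3 resolved from the `KERAS_BACKEND` environment variable. As a sketch of how that interacts with the new message (the variable is standard Keras 3 behavior, not something this commit adds):

```python
import os

# Keras 3 reads KERAS_BACKEND once, at import time, so it must be set
# before keras (and therefore bayesflow) is first imported
os.environ["KERAS_BACKEND"] = "jax"

import bayesflow  # setup() now logs: Using backend 'jax'
```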

bayesflow/adapters/adapter.py

Lines changed: 64 additions & 7 deletions

```diff
@@ -1,4 +1,4 @@
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, MutableSequence, Sequence

 import numpy as np
 from keras.saving import (
@@ -26,17 +26,16 @@
     ToArray,
     Transform,
 )
-
 from .transforms.filter_transform import Predicate


 @serializable(package="bayesflow.adapters")
-class Adapter:
+class Adapter(MutableSequence[Transform]):
     def __init__(self, transforms: Sequence[Transform] | None = None):
         if transforms is None:
             transforms = []

-        self.transforms = transforms
+        self.transforms = list(transforms)

     @staticmethod
     def create_default(inference_variables: Sequence[str]) -> "Adapter":
@@ -77,12 +76,70 @@ def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) ->
         return self.forward(data, **kwargs)

     def __repr__(self):
-        return f"Adapter([{' -> '.join(map(repr, self.transforms))}])"
+        result = ""
+        for i, transform in enumerate(self):
+            result += f"{i}: {transform!r}"
+            if i != len(self) - 1:
+                result += " -> "

-    def add_transform(self, transform: Transform):
-        self.transforms.append(transform)
+        return f"Adapter([{result}])"
+
+    # list methods
+
+    def append(self, value: Transform) -> "Adapter":
+        self.transforms.append(value)
         return self

+    def __delitem__(self, key: int | slice):
+        del self.transforms[key]
+
+    def extend(self, values: Sequence[Transform]) -> "Adapter":
+        if isinstance(values, Adapter):
+            values = values.transforms
+
+        self.transforms.extend(values)
+
+        return self
+
+    def __getitem__(self, item: int | slice) -> "Adapter":
+        if isinstance(item, int):
+            return self.transforms[item]
+
+        return Adapter(self.transforms[item])
+
+    def insert(self, index: int, value: Transform | Sequence[Transform]) -> "Adapter":
+        if isinstance(value, Adapter):
+            value = value.transforms
+
+        if isinstance(value, Sequence):
+            # convenience: Adapters are always flat
+            self.transforms = self.transforms[:index] + list(value) + self.transforms[index:]
+        else:
+            self.transforms.insert(index, value)
+
+        return self
+
+    def __setitem__(self, key: int | slice, value: Transform | Sequence[Transform]) -> "Adapter":
+        if isinstance(value, Adapter):
+            value = value.transforms
+
+        if isinstance(key, int) and isinstance(value, Sequence):
+            if key < 0:
+                key += len(self.transforms)
+
+            key = slice(key, key + 1)
+
+        self.transforms[key] = value
+
+        return self
+
+    def __len__(self):
+        return len(self.transforms)
+
+    # adapter methods
+
+    add_transform = append
+
     def apply(
         self,
         *,
```
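
Since `Adapter` is now a `MutableSequence[Transform]`, its transforms can be read, replaced, inserted, and deleted with ordinary sequence syntax: integer indexing returns a single `Transform`, slicing returns a new `Adapter`, and nested adapters are flattened on `insert`. A usage sketch, assuming `ToArray` (already imported by this module) constructs without arguments:

```python
from bayesflow.adapters import Adapter
from bayesflow.adapters.transforms import ToArray

adapter = Adapter([ToArray(), ToArray()])   # contrived contents, purely illustrative

assert len(adapter) == 2                    # __len__ delegates to the transform list
first = adapter[0]                          # int index -> a single Transform
tail = adapter[1:]                          # slice -> a new Adapter
adapter.insert(1, Adapter([ToArray()]))     # nested Adapters are flattened in place
del adapter[1]                              # __delitem__ forwards to the list
adapter.append(ToArray())                   # returns self, so calls chain fluently
```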

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -176,7 +176,7 @@ def calibration_ecdf(
         titles = ["Stacked ECDFs"]

     for ax, title in zip(plot_data["axes"].flat, titles):
-        ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands")
+        ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
         ax.legend(fontsize=legend_fontsize)
         ax.set_title(title, fontsize=title_fontsize)
```

bayesflow/diagnostics/plots/mc_calibration.py

Lines changed: 14 additions & 13 deletions

```diff
@@ -68,41 +68,42 @@ def mc_calibration(

     # Gather plot data and metadata into a dictionary
     plot_data = prepare_plot_data(
-        estimates=pred_models,
-        ground_truths=true_models,
+        targets=pred_models,
+        references=true_models,
         variable_names=model_names,
         num_col=num_col,
         num_row=num_row,
         figsize=figsize,
+        default_name="M",
     )

     # Compute calibration
     cal_errors, true_probs, pred_probs = expected_calibration_error(
-        plot_data["ground_truths"], plot_data["estimates"], num_bins
+        plot_data["references"], plot_data["targets"], num_bins
     )

     for j, ax in enumerate(plot_data["axes"].flat):
         # Plot calibration curve
-        ax[j].plot(pred_probs[j], true_probs[j], "o-", color=color)
+        ax.plot(pred_probs[j], true_probs[j], "o-", color=color)

         # Plot PMP distribution over bins
         uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
-        norm_weights = np.ones_like(plot_data["estimates"]) / len(plot_data["estimates"])
-        ax[j].hist(plot_data["estimates"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)
+        norm_weights = np.ones_like(plot_data["targets"]) / len(plot_data["targets"])
+        ax.hist(plot_data["targets"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)

         # Plot AB line
-        ax[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9)
+        ax.plot((0, 1), (0, 1), "--", color="black", alpha=0.9)

         # Tweak plot
-        ax[j].set_xlim([0 - epsilon, 1 + epsilon])
-        ax[j].set_ylim([0 - epsilon, 1 + epsilon])
-        ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
-        ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
+        ax.set_xlim([0 - epsilon, 1 + epsilon])
+        ax.set_ylim([0 - epsilon, 1 + epsilon])
+        ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
+        ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

         # Add ECE label
         add_metric(
-            ax[j],
-            metric_text=r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}",
+            ax,
+            metric_text=r"$\widehat{{\mathrm{{ECE}}}}$",
             metric_value=cal_errors[j],
             metric_fontsize=metric_fontsize,
         )
```
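
The `ax[j]` to `ax` changes fix an indexing bug: `enumerate(plot_data["axes"].flat)` already yields one `Axes` object per iteration, so subscripting it raised a `TypeError`. The corrected pattern, shown standalone with plain matplotlib:

```python
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2)
for j, ax in enumerate(axes.flat):
    # `ax` is already a single Axes; `ax[j]` would raise a TypeError
    ax.set_title(f"panel {j}")
```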

bayesflow/diagnostics/plots/mmd_hypothesis_test.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -79,7 +79,7 @@ def fill_area_under_kde(kde_object, x_start, x_end=None, **kwargs):

     mmd_critical = ops.quantile(mmd_null, 1 - alpha_level)
     fill_area_under_kde(
-        kde, mmd_critical, color=alpha_color, alpha=0.5, label=rf"{int(alpha_level*100)}% rejection area"
+        kde, mmd_critical, color=alpha_color, alpha=0.5, label=rf"{int(alpha_level * 100)}% rejection area"
     )

     if truncate_v_lines_at_kde:
```

bayesflow/networks/consistency_models/continuous_consistency_model.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -249,7 +249,7 @@ def f_teacher(x, t):
             ops.cos(t) * ops.sin(t) * self.sigma_data,
         )

-        teacher_output, cos_sin_dFdt = jvp(f_teacher, primals, tangents)
+        teacher_output, cos_sin_dFdt = jvp(f_teacher, primals, tangents, return_output=True)
         teacher_output = ops.stop_gradient(teacher_output)
         cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)
```
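
For context on `return_output=True`: a Jacobian-vector product can return both the function's output and its directional derivative from a single forward pass, and the flag presumably asks BayesFlow's `jvp` utility for that pair (the wrapper's exact signature is assumed here, not shown in this diff). The underlying idea, illustrated with `torch.func.jvp`, which always returns both values:

```python
import torch

def f(x):
    return x ** 2

primals = (torch.tensor(3.0),)
tangents = (torch.tensor(1.0),)

# torch.func.jvp returns (f(x), df/dx . tangent) from one evaluation
output, tangent_out = torch.func.jvp(f, primals, tangents)
print(output, tangent_out)  # tensor(9.), tensor(6.)
```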

bayesflow/networks/coupling_flow/couplings/single_coupling.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,5 +1,4 @@
 import keras
-
 from keras.saving import register_keras_serializable as serializable

 from bayesflow.types import Tensor
@@ -24,6 +23,7 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar

         output_projector_kwargs = kwargs.get("output_projector_kwargs", {})
         output_projector_kwargs.setdefault("kernel_initializer", "zeros")
+        output_projector_kwargs.setdefault("bias_initializer", "zeros")
         self.output_projector = keras.layers.Dense(units=None, **output_projector_kwargs)

         # serialization: store all parameters necessary to call __init__
```
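
Defaulting the bias initializer to zeros alongside the zero kernel presumably keeps the coupling an identity map at initialization: the output projector then emits all-zero transform parameters for any input. A quick standalone check of that property (illustrative, not from the commit):

```python
import keras
import numpy as np

# with zero kernel and zero bias, a Dense layer maps every input to zeros,
# so the transform parameterized by its output starts out trivial
layer = keras.layers.Dense(units=4, kernel_initializer="zeros", bias_initializer="zeros")
out = layer(np.random.randn(2, 3).astype("float32"))
print(np.allclose(keras.ops.convert_to_numpy(out), 0.0))  # True
```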
Lines changed: 81 additions & 0 deletions

```diff
@@ -0,0 +1,81 @@
+from typing import TypedDict
+
+import keras
+
+from bayesflow.types import Tensor
+
+
+class Edges(TypedDict):
+    left: Tensor
+    right: Tensor
+    bottom: Tensor
+    top: Tensor
+
+
+class Derivatives(TypedDict):
+    left: Tensor
+    right: Tensor
+
+
+def _rational_quadratic_spline(
+    x: Tensor, edges: Edges, derivatives: Derivatives, inverse: bool = False
+) -> (Tensor, Tensor):
+    # rename variables to match the paper:
+
+    # $x^{(k)}$
+    xk = edges["left"]
+
+    # $x^{(k+1)}$
+    xkp = edges["right"]
+
+    # $y^{(k)}$
+    yk = edges["bottom"]
+
+    # $\delta^{(k)}$
+    dk = derivatives["left"]
+
+    # $y^{(k+1)}$
+    ykp = edges["top"]
+
+    # $\delta^{(k+1)}$
+    dkp = derivatives["right"]
+
+    # commonly used values
+    dx = xkp - xk
+    dy = ykp - yk
+    sk = dy / dx
+
+    if not inverse:
+        xi = (x - xk) / dx
+
+        # Eq. 4 in the paper
+        numerator = dy * (sk * xi**2 + dk * xi * (1 - xi))
+        denominator = sk + (dkp + dk - 2 * sk) * xi * (1 - xi)
+        result = yk + numerator / denominator
+    else:
+        # rename for clarity
+        y = x
+
+        # Eq. 6-8 in the paper
+        a = dy * (sk - dk) + (y - yk) * (dkp + dk - 2 * sk)
+        b = dy * dk - (y - yk) * (dkp + dk - 2 * sk)
+        c = -sk * (y - yk)
+
+        # Eq. 29 in the appendix of the paper
+        discriminant = b**2 - 4 * a * c
+
+        # the discriminant must be positive, even when the spline is called out of bounds
+        discriminant = keras.ops.maximum(discriminant, 0)
+
+        xi = 2 * c / (-b - keras.ops.sqrt(discriminant))
+        result = xi * dx + xk
+
+    # Eq. 5 in the paper
+    numerator = sk**2 * (dkp * xi**2 + 2 * sk * xi * (1 - xi) + dk * (1 - xi) ** 2)
+    denominator = (sk + (dkp + dk - 2 * sk) * xi * (1 - xi)) ** 2
+    log_jac = keras.ops.log(numerator) - keras.ops.log(denominator)
+
+    if inverse:
+        log_jac = -log_jac
+
+    return result, log_jac
```
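
The comments in this new module reference equation numbers from the rational-quadratic spline construction of Durkan et al., "Neural Spline Flows". A toy round-trip under assumed inputs (a single bin mapping [0, 1] to [0, 2] with unit edge derivatives, and assuming `_rational_quadratic_spline` is importable from the new module, whose path is not shown in this capture) checks that forward and inverse compose to the identity and the log-Jacobians cancel:

```python
import numpy as np

edges = {
    "left": np.array(0.0),
    "right": np.array(1.0),
    "bottom": np.array(0.0),
    "top": np.array(2.0),
}
derivatives = {"left": np.array(1.0), "right": np.array(1.0)}

y, log_jac = _rational_quadratic_spline(np.array(0.5), edges, derivatives)
x, inv_log_jac = _rational_quadratic_spline(y, edges, derivatives, inverse=True)

print(np.allclose(x, 0.5))                    # True: the inverse recovers the input
print(np.allclose(log_jac + inv_log_jac, 0))  # True: the log-Jacobians cancel
```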
