Skip to content

Commit 8b47c03

Browse files
Two moons test fit (#175)
* Squashed commit of the following: commit 320f1ae Author: lars <[email protected]> Date: Tue Jun 18 16:23:01 2024 +0200 fix two moons simulated dtype commit 27f99cd Author: lars <[email protected]> Date: Tue Jun 18 16:09:45 2024 +0200 fix data modification for tensorflow compiled mode commit c8060fc Merge: 3150d11 e2355de Author: lars <[email protected]> Date: Tue Jun 18 15:35:59 2024 +0200 Merge remote-tracking branch 'origin/streamlined-backend' into streamlined-backend commit 3150d11 Author: lars <[email protected]> Date: Tue Jun 18 15:35:52 2024 +0200 add JAX Approximator finalize all Approximators commit e2355de Author: Chase Grajeda <[email protected]> Date: Tue Jun 18 22:15:37 2024 +0900 Configurator Unit Tests (#174) * First additions Added __init__.py for test module. Added test_configurators.py. Added basic fixtures and construction tests. * Remaining tests Added remaining unit tests * Added conftest Separated fixtures and placed them in conftest.py * Added requested changes Added batch_size, set_size, and num_features parameterizations in conftest.py. Combined repetitive fixtures in conftest.py. Combined repetitive tests in test_configurators.py. Parameterized Configurator initialization in conftest.py. Parameterized parameter selection in conftest.py. Removed initialization tests in test_configurators.py. Added summary_inputs and summary_conditions to parameters. Changed instances of '==None' to 'is None'. Removed 'config=Configurator' instances in test_configurators.py. * Added loss test Added test for post-training loss < pre-training loss to test_fit.py::test_fit * Added vanishing weights test Added test in test_fit.py::test_fit for vanishing weights * Added simulator test Added test to test_fit.py for verifying the simulator produces random and consistent data * Added MMD test Added MMD test to test_two_moons.py. Added MMD method to utils/ops.py. Added test_dataset to test_two_moons/conftest.py. * Linting adjustments Added auto-formatting changes from ruff
1 parent 07e60dd commit 8b47c03

File tree

4 files changed

+69
-9
lines changed

4 files changed

+69
-9
lines changed

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,4 @@ def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> d
137137

138138
loss = keras.losses.mean_squared_error(predicted_velocity, target_velocity)
139139

140-
return {"loss": loss}
140+
return {"loss": loss}

tests/test_two_moons/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,14 @@ def validation_dataset(simulator, batch_size):
5050
from bayesflow import OfflineDataset
5151

5252
num_batches = 4
53-
data = simulator.sample((4 * batch_size,))
53+
data = simulator.sample((num_batches * batch_size,))
54+
return OfflineDataset(data, workers=4, max_queue_size=num_batches, batch_size=batch_size)
55+
56+
57+
@pytest.fixture()
58+
def test_dataset(simulator, batch_size):
59+
from bayesflow import OfflineDataset
60+
61+
num_batches = 16
62+
data = simulator.sample((num_batches * batch_size,))
5463
return OfflineDataset(data, workers=4, max_queue_size=num_batches, batch_size=batch_size)

tests/test_two_moons/test_two_moons.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import keras
22
import pytest
33

4-
from tests.utils import assert_models_equal
4+
from tests.utils import assert_models_equal, max_mean_discrepancy
55
from tests.utils import InterruptFitCallback, FitInterruptedError
66

77

@@ -11,14 +11,38 @@ def test_compile(approximator, random_samples, jit_compile):
1111

1212

1313
@pytest.mark.parametrize("jit_compile", [False, True])
14-
def test_fit(approximator, train_dataset, validation_dataset, jit_compile):
15-
# TODO: verify the model learns something by comparing a metric before and after training
16-
approximator.compile(jit_compile=jit_compile)
17-
approximator.fit(
14+
def test_fit(approximator, train_dataset, validation_dataset, test_dataset, jit_compile):
15+
# TODO: Refactor to use approximator.sample() when implemented (instead of calling the inference network directly)
16+
17+
approximator.compile(jit_compile=jit_compile, loss=keras.losses.KLDivergence())
18+
inf_vars = approximator.configurator.configure_inference_variables(test_dataset.data)
19+
inf_conds = approximator.configurator.configure_inference_conditions(test_dataset.data)
20+
y = test_dataset.data["x"]
21+
22+
pre_loss = approximator.compute_metrics(train_dataset.data)["loss"]
23+
pre_val_loss = approximator.compute_metrics(validation_dataset.data)["loss"]
24+
x_before = approximator.inference_network(inf_vars, conditions=inf_conds)
25+
mmd_before = max_mean_discrepancy(x_before, y)
26+
27+
history = approximator.fit(
1828
train_dataset,
1929
validation_data=validation_dataset,
20-
epochs=2,
21-
)
30+
epochs=3,
31+
).history
32+
x_after = approximator.inference_network(inf_vars, conditions=inf_conds)
33+
mmd_after = max_mean_discrepancy(x_after, y)
34+
35+
# Test model weights have not vanished
36+
for layer in approximator.layers:
37+
for weight in layer.weights:
38+
assert not keras.ops.any(keras.ops.isnan(weight)).numpy()
39+
40+
# Test KLD loss and validation loss decrease after training
41+
assert history["loss"][-1] < pre_loss
42+
assert history["val_loss"][-1] < pre_val_loss
43+
44+
# Test MMD improved after training
45+
assert mmd_after < mmd_before
2246

2347

2448
@pytest.mark.parametrize("jit_compile", [False, True])

tests/utils/ops.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,30 @@ def isclose(x1, x2, rtol=1e-5, atol=1e-8):
77

88
def allclose(x1, x2, rtol=1e-5, atol=1e-8):
99
return keras.ops.all(isclose(x1, x2, rtol, atol))
10+
11+
12+
def max_mean_discrepancy(x, y):
13+
# Computes the Max Mean Discrepancy between samples of two distributions
14+
xx = keras.ops.matmul(x, keras.ops.transpose(x))
15+
yy = keras.ops.matmul(y, keras.ops.transpose(y))
16+
zz = keras.ops.matmul(x, keras.ops.transpose(y))
17+
18+
rx = keras.ops.broadcast_to(keras.ops.expand_dims(keras.ops.diag(xx), 0), xx.shape)
19+
ry = keras.ops.broadcast_to(keras.ops.expand_dims(keras.ops.diag(yy), 0), yy.shape)
20+
21+
dxx = keras.ops.transpose(rx) + rx - 2.0 * xx
22+
dyy = keras.ops.transpose(ry) + ry - 2.0 * yy
23+
dxy = keras.ops.transpose(rx) + ry - 2.0 * zz
24+
25+
XX = keras.ops.zeros(xx.shape)
26+
YY = keras.ops.zeros(yy.shape)
27+
XY = keras.ops.zeros(zz.shape)
28+
29+
# RBF scaling
30+
bandwidth = [10, 15, 20, 50]
31+
for a in bandwidth:
32+
XX += keras.ops.exp(-0.5 * dxx / a)
33+
YY += keras.ops.exp(-0.5 * dyy / a)
34+
XY += keras.ops.exp(-0.5 * dxy / a)
35+
36+
return keras.ops.mean(XX + YY - 2.0 * XY)

0 commit comments

Comments
 (0)