
Commit ff35ca6

Merge pull request #176 from Chase-Grajeda/loss-progress-tests

Loss progress tests

2 parents: 649436c + 2da5ec7

4 files changed: 78 additions & 23 deletions


bayesflow/simulators/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
 from .sequential_simulator import SequentialSimulator
 from .two_moons_simulator import TwoMoonsSimulator
+from .normal_simulator import NormalSimulator
bayesflow/simulators/normal_simulator.py

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+import keras
+
+from .simulator import Simulator
+from ..types import Shape, Tensor
+
+
+class NormalSimulator(Simulator):
+    """TODO: Docstring"""
+
+    def sample(self, batch_shape: Shape, num_observations: int = 32) -> dict[str, Tensor]:
+        mean = keras.random.normal(batch_shape + (2,), 0.0, 0.1)
+        mean = keras.ops.repeat(mean[:, None], num_observations, 1)
+
+        std = keras.ops.exp(keras.random.normal(batch_shape + (2,), 0.0, 0.1))
+        std = keras.ops.repeat(std[:, None], num_observations, 1)
+
+        noise = keras.random.normal(batch_shape + (num_observations, 2))
+
+        x = mean + std * noise
+        return dict(mean=mean, std=std, x=x)
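
For orientation, here is a minimal usage sketch of the new simulator (not part of the diff; it assumes a configured Keras 3 backend). The shapes follow directly from sample() above: every returned tensor is broadcast to batch_shape + (num_observations, 2).

# Illustrative sketch only, based on the NormalSimulator added in this commit.
from bayesflow.simulators import NormalSimulator

simulator = NormalSimulator()
samples = simulator.sample((8,), num_observations=32)

# mean and std are repeated along the observation axis, so all three
# entries share the shape batch_shape + (num_observations, 2):
# samples["mean"].shape == (8, 32, 2)
# samples["std"].shape  == (8, 32, 2)
# samples["x"].shape    == (8, 32, 2)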
Lines changed: 28 additions & 23 deletions

@@ -1,7 +1,9 @@
-import keras
 import pytest
 
-import bayesflow as bf
+
+@pytest.fixture()
+def batch_size():
+    return 8
 
 
 @pytest.fixture()
@@ -11,40 +13,43 @@ def summary_network():
 
 @pytest.fixture()
 def inference_network():
-    network = keras.Sequential([keras.layers.Dense(10)])
-    network.compile(loss="mse")
-    return network
+    from bayesflow.networks import CouplingFlow
+
+    return CouplingFlow()
 
 
 @pytest.fixture()
 def approximator(inference_network, summary_network):
-    return bf.Approximator(
+    from bayesflow import Approximator
+
+    return Approximator(
         inference_network=inference_network,
         summary_network=summary_network,
-        inference_variables=[],
-        inference_conditions=[],
-        summary_variables=[],
-        summary_conditions=[],
+        inference_variables=["mean", "std"],
+        inference_conditions=["x"],
     )
 
 
 @pytest.fixture()
-def dataset():
-    batch_size = 16
-    batches_per_epoch = 4
-    parameter_sets = batch_size * batches_per_epoch
-    observations_per_parameter_set = 32
+def simulator():
+    from bayesflow.simulators import NormalSimulator
 
-    mean = keras.random.normal(mean=0.0, stddev=0.1, shape=(parameter_sets, 2))
-    std = keras.ops.exp(keras.random.normal(mean=0.0, stddev=0.1, shape=(parameter_sets, 2)))
+    return NormalSimulator()
 
-    mean = keras.ops.repeat(mean[:, None], observations_per_parameter_set, 1)
-    std = keras.ops.repeat(std[:, None], observations_per_parameter_set, 1)
 
-    noise = keras.random.normal(shape=(parameter_sets, observations_per_parameter_set, 2))
+@pytest.fixture()
+def train_dataset(simulator, batch_size):
+    from bayesflow import OfflineDataset
+
+    num_batches = 4
+    data = simulator.sample((num_batches * batch_size,))
+    return OfflineDataset(data, workers=4, max_queue_size=num_batches, batch_size=batch_size)
 
-    x = mean + std * noise
 
-    data = dict(observables=dict(x=x), parameters=dict(mean=mean, std=std))
+@pytest.fixture()
+def validation_dataset(simulator, batch_size):
+    from bayesflow import OfflineDataset
 
-    return bf.datasets.OfflineDataset(data, batch_size=batch_size, batches_per_epoch=batches_per_epoch)
+    num_batches = 2
+    data = simulator.sample((num_batches * batch_size,))
+    return OfflineDataset(data, workers=4, max_queue_size=num_batches, batch_size=batch_size)
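
With batch_size fixed at 8, the two dataset fixtures pre-simulate 4 * 8 = 32 and 2 * 8 = 16 parameter sets per epoch. A rough stand-alone equivalent of the simulator and dataset wiring (illustrative only; it simply mirrors the fixture bodies above rather than an official recipe):

# Illustrative, hand-rolled equivalent of the fixture chain above.
from bayesflow import OfflineDataset
from bayesflow.simulators import NormalSimulator

batch_size = 8                   # batch_size fixture
simulator = NormalSimulator()    # simulator fixture

# train_dataset fixture: 4 batches -> 32 pre-simulated parameter sets
train_data = simulator.sample((4 * batch_size,))
train_dataset = OfflineDataset(train_data, workers=4, max_queue_size=4, batch_size=batch_size)

# validation_dataset fixture: 2 batches -> 16 pre-simulated parameter sets
val_data = simulator.sample((2 * batch_size,))
validation_dataset = OfflineDataset(val_data, workers=4, max_queue_size=2, batch_size=batch_size)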

tests/test_approximators/test_fit.py

Lines changed: 29 additions & 0 deletions
@@ -1,4 +1,7 @@
 import pytest
+import io
+import numpy as np
+from contextlib import redirect_stdout
 
 
 @pytest.mark.skip(reason="not implemented")
@@ -12,3 +15,29 @@ def test_fit(amortizer, dataset):
     amortizer.fit(dataset)
 
     assert amortizer.losses is not None
+
+
+def test_loss_progress(approximator, train_dataset, validation_dataset):
+    approximator.compile(optimizer="AdamW")
+    num_epochs = 3
+
+    # Capture ostream and train model
+    ostream = io.StringIO()
+    with redirect_stdout(ostream):
+        history = approximator.fit(train_dataset, validation_data=validation_dataset, epochs=num_epochs).history
+    output = ostream.getvalue()
+    ostream.close()
+
+    loss_output = [line.replace("\x08", "") for line in output.splitlines() if "loss" in line]
+
+    # Test losses are not NaN and that epoch summaries match loss histories
+    epoch = 0
+    for loss_stats in loss_output:
+        content = loss_stats.split()
+        if "val_loss" in loss_stats:
+            assert float(content[-4]) == round(history["loss"][epoch], 4)
+            assert float(content[-1]) == round(history["val_loss"][epoch], 4)
+            epoch += 1
+            continue
+
+        assert not np.isnan(float(content[-1]))
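
The index arithmetic in test_loss_progress leans on the layout of Keras' epoch summary line. Assuming the usual format, where the training loss precedes a trailing val_loss entry, whitespace-splitting puts the training loss at content[-4] and the validation loss at content[-1]:

# Illustrative only: a made-up epoch summary line in the usual Keras format.
line = "4/4 - 0s - 12ms/step - loss: 1.2345 - val_loss: 1.3456"
content = line.split()

assert content[-4] == "1.2345"   # training loss, compared against history["loss"]
assert content[-1] == "1.3456"   # validation loss, compared against history["val_loss"]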
