Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loss progress tests #176

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bayesflow/simulators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .sequential_simulator import SequentialSimulator
from .two_moons_simulator import TwoMoonsSimulator
from .normal_simulator import NormalSimulator
20 changes: 20 additions & 0 deletions bayesflow/simulators/normal_simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import keras

from .simulator import Simulator
from ..types import Shape, Tensor


class NormalSimulator(Simulator):
"""TODO: Docstring"""

def sample(self, batch_shape: Shape, num_observations: int = 32) -> dict[str, Tensor]:
mean = keras.random.normal(batch_shape + (2,), 0.0, 0.1)
mean = keras.ops.repeat(mean[:, None], num_observations, 1)

std = keras.ops.exp(keras.random.normal(batch_shape + (2,), 0.0, 0.1))
std = keras.ops.repeat(std[:, None], num_observations, 1)

noise = keras.random.normal(batch_shape + (num_observations, 2))

x = mean + std * noise
return dict(mean=mean, std=std, x=x)
51 changes: 28 additions & 23 deletions tests/test_approximators/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import keras
import pytest

import bayesflow as bf

@pytest.fixture()
def batch_size():
return 8


@pytest.fixture()
Expand All @@ -11,40 +13,43 @@ def summary_network():

@pytest.fixture()
def inference_network():
network = keras.Sequential([keras.layers.Dense(10)])
network.compile(loss="mse")
return network
from bayesflow.networks import CouplingFlow

return CouplingFlow()


@pytest.fixture()
def approximator(inference_network, summary_network):
return bf.Approximator(
from bayesflow import Approximator

return Approximator(
inference_network=inference_network,
summary_network=summary_network,
inference_variables=[],
inference_conditions=[],
summary_variables=[],
summary_conditions=[],
inference_variables=["mean", "std"],
inference_conditions=["x"],
)


@pytest.fixture()
def dataset():
batch_size = 16
batches_per_epoch = 4
parameter_sets = batch_size * batches_per_epoch
observations_per_parameter_set = 32
def simulator():
from bayesflow.simulators import NormalSimulator

mean = keras.random.normal(mean=0.0, stddev=0.1, shape=(parameter_sets, 2))
std = keras.ops.exp(keras.random.normal(mean=0.0, stddev=0.1, shape=(parameter_sets, 2)))
return NormalSimulator()

mean = keras.ops.repeat(mean[:, None], observations_per_parameter_set, 1)
std = keras.ops.repeat(std[:, None], observations_per_parameter_set, 1)

noise = keras.random.normal(shape=(parameter_sets, observations_per_parameter_set, 2))
@pytest.fixture()
def train_dataset(simulator, batch_size):
from bayesflow import OfflineDataset

num_batches = 4
data = simulator.sample((num_batches * batch_size,))
return OfflineDataset(data, workers=4, max_queue_size=num_batches, batch_size=batch_size)

x = mean + std * noise

data = dict(observables=dict(x=x), parameters=dict(mean=mean, std=std))
@pytest.fixture()
def validation_dataset(simulator, batch_size):
from bayesflow import OfflineDataset

return bf.datasets.OfflineDataset(data, batch_size=batch_size, batches_per_epoch=batches_per_epoch)
num_batches = 2
data = simulator.sample((num_batches * batch_size,))
return OfflineDataset(data, workers=4, max_queue_size=num_batches, batch_size=batch_size)
29 changes: 29 additions & 0 deletions tests/test_approximators/test_fit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import pytest
import io
import numpy as np
from contextlib import redirect_stdout


@pytest.mark.skip(reason="not implemented")
Expand All @@ -12,3 +15,29 @@ def test_fit(amortizer, dataset):
amortizer.fit(dataset)

assert amortizer.losses is not None


def test_loss_progress(approximator, train_dataset, validation_dataset):
approximator.compile(optimizer="AdamW")
num_epochs = 3

# Capture ostream and train model
ostream = io.StringIO()
with redirect_stdout(ostream):
history = approximator.fit(train_dataset, validation_data=validation_dataset, epochs=num_epochs).history
output = ostream.getvalue()
ostream.close()

loss_output = [line.replace("\x08", "") for line in output.splitlines() if "loss" in line]

# Test losses are not NaN and that epoch summaries match loss histories
epoch = 0
for loss_stats in loss_output:
content = loss_stats.split()
if "val_loss" in loss_stats:
assert float(content[-4]) == round(history["loss"][epoch], 4)
assert float(content[-1]) == round(history["val_loss"][epoch], 4)
epoch += 1
continue

assert not np.isnan(float(content[-1]))
Loading