Skip to content

Commit

Permalink
Merge pull request #176 from Chase-Grajeda/loss-progress-tests
Browse files Browse the repository at this point in the history
Loss progress tests
  • Loading branch information
stefanradev93 authored Jul 5, 2024
2 parents 649436c + 2da5ec7 commit ff35ca6
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 23 deletions.
1 change: 1 addition & 0 deletions bayesflow/simulators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .sequential_simulator import SequentialSimulator
from .two_moons_simulator import TwoMoonsSimulator
from .normal_simulator import NormalSimulator
20 changes: 20 additions & 0 deletions bayesflow/simulators/normal_simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import keras

from .simulator import Simulator
from ..types import Shape, Tensor


class NormalSimulator(Simulator):
    """Toy simulator for a 2D Gaussian model with unknown location and scale.

    Per batch element it draws a 2-vector ``mean`` and a 2-vector ``std``
    (the scale is ``exp`` of a normal draw, i.e. log-normal and therefore
    strictly positive), then generates ``num_observations`` observations
    ``x = mean + std * noise`` with standard-normal ``noise``.
    """

    def sample(self, batch_shape: Shape, num_observations: int = 32) -> dict[str, Tensor]:
        # Prior draw for the mean (keras.random.normal positional args are
        # mean=0.0, stddev=0.1), then tiled along a new observation axis so
        # its shape matches x: batch_shape + (num_observations, 2).
        mean = keras.random.normal(batch_shape + (2,), 0.0, 0.1)
        mean = keras.ops.repeat(mean[:, None], num_observations, 1)

        # Scale prior: exp of a Normal(0, 0.1) draw, tiled the same way.
        std = keras.ops.exp(keras.random.normal(batch_shape + (2,), 0.0, 0.1))
        std = keras.ops.repeat(std[:, None], num_observations, 1)

        # Standard-normal observation noise, one draw per observation.
        noise = keras.random.normal(batch_shape + (num_observations, 2))

        # Location-scale transform; parameters and data share leading shape.
        x = mean + std * noise
        return dict(mean=mean, std=std, x=x)
51 changes: 28 additions & 23 deletions tests/test_approximators/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import keras
import pytest

import bayesflow as bf

@pytest.fixture()
def batch_size():
    """Number of samples per batch, shared by the dataset fixtures."""
    samples_per_batch = 8
    return samples_per_batch


@pytest.fixture()
Expand All @@ -11,40 +13,43 @@ def summary_network():

@pytest.fixture()
def inference_network():
network = keras.Sequential([keras.layers.Dense(10)])
network.compile(loss="mse")
return network
from bayesflow.networks import CouplingFlow

return CouplingFlow()


@pytest.fixture()
def approximator(inference_network, summary_network):
return bf.Approximator(
from bayesflow import Approximator

return Approximator(
inference_network=inference_network,
summary_network=summary_network,
inference_variables=[],
inference_conditions=[],
summary_variables=[],
summary_conditions=[],
inference_variables=["mean", "std"],
inference_conditions=["x"],
)


@pytest.fixture()
def dataset():
batch_size = 16
batches_per_epoch = 4
parameter_sets = batch_size * batches_per_epoch
observations_per_parameter_set = 32
def simulator():
from bayesflow.simulators import NormalSimulator

mean = keras.random.normal(mean=0.0, stddev=0.1, shape=(parameter_sets, 2))
std = keras.ops.exp(keras.random.normal(mean=0.0, stddev=0.1, shape=(parameter_sets, 2)))
return NormalSimulator()

mean = keras.ops.repeat(mean[:, None], observations_per_parameter_set, 1)
std = keras.ops.repeat(std[:, None], observations_per_parameter_set, 1)

noise = keras.random.normal(shape=(parameter_sets, observations_per_parameter_set, 2))
@pytest.fixture()
def train_dataset(simulator, batch_size):
from bayesflow import OfflineDataset

num_batches = 4
data = simulator.sample((num_batches * batch_size,))
return OfflineDataset(data, workers=4, max_queue_size=num_batches, batch_size=batch_size)

x = mean + std * noise

data = dict(observables=dict(x=x), parameters=dict(mean=mean, std=std))
@pytest.fixture()
def validation_dataset(simulator, batch_size):
from bayesflow import OfflineDataset

return bf.datasets.OfflineDataset(data, batch_size=batch_size, batches_per_epoch=batches_per_epoch)
num_batches = 2
data = simulator.sample((num_batches * batch_size,))
return OfflineDataset(data, workers=4, max_queue_size=num_batches, batch_size=batch_size)
29 changes: 29 additions & 0 deletions tests/test_approximators/test_fit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import pytest
import io
import numpy as np
from contextlib import redirect_stdout


@pytest.mark.skip(reason="not implemented")
Expand All @@ -12,3 +15,29 @@ def test_fit(amortizer, dataset):
amortizer.fit(dataset)

assert amortizer.losses is not None


def test_loss_progress(approximator, train_dataset, validation_dataset):
    """Train briefly and check that the losses Keras prints to stdout are
    finite and that each epoch-summary line agrees with the recorded
    ``History`` object."""
    approximator.compile(optimizer="AdamW")
    num_epochs = 3

    # Capture ostream and train model
    ostream = io.StringIO()
    with redirect_stdout(ostream):
        history = approximator.fit(train_dataset, validation_data=validation_dataset, epochs=num_epochs).history
    output = ostream.getvalue()
    ostream.close()

    # The Keras progress bar redraws itself with backspace characters
    # (\x08); strip those and keep only lines that report a loss value.
    loss_output = [line.replace("\x08", "") for line in output.splitlines() if "loss" in line]

    # Test losses are not NaN and that epoch summaries match loss histories
    epoch = 0
    for loss_stats in loss_output:
        content = loss_stats.split()
        if "val_loss" in loss_stats:
            # Epoch-summary line, assumed shaped like
            # "... - loss: <train> - val_loss: <val>" with values printed to
            # 4 decimals (hence the round(..., 4) comparison): token[-4] is
            # the training loss, token[-1] the validation loss.
            # NOTE(review): brittle against Keras output-format changes.
            assert float(content[-4]) == round(history["loss"][epoch], 4)
            assert float(content[-1]) == round(history["val_loss"][epoch], 4)
            epoch += 1
            continue

        # Intermediate progress line: last token is the running loss.
        assert not np.isnan(float(content[-1]))

0 comments on commit ff35ca6

Please sign in to comment.