Skip to content

Commit

Permalink
reorganize fixtures, improve parametrized test id generation
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jul 4, 2024
1 parent cfb4638 commit 5a72c85
Showing 1 changed file with 46 additions and 10 deletions.
56 changes: 46 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,63 @@ def pytest_runtest_setup(item):
pytest.skip(f"Skipping backend '{backend}' for test {item}, which is registered for backends {test_backends}.")


@pytest.fixture(autouse=True, scope="function")
def random_seed():
seed = 0
keras.utils.set_random_seed(seed)
return seed
def pytest_make_parametrize_id(config, val, argname):
return f"{argname}={repr(val)}"


@pytest.fixture(params=[1, 2, 3, 4], scope="session")
def batch_size(request):
return request.param


@pytest.fixture()
@pytest.fixture(scope="function")
def coupling_flow():
from bayesflow.networks import CouplingFlow

return CouplingFlow(depth=2, subnet_kwargs=dict(depth=2, width=64))
return CouplingFlow(depth=2, subnet_kwargs=dict(depth=2, width=32))


@pytest.fixture()
@pytest.fixture(params=["two_moons"], scope="session")
def dataset(request):
return request.getfixturevalue(request.param)


@pytest.fixture(scope="function")
def flow_matching():
from bayesflow.networks import FlowMatching

return FlowMatching(network_kwargs=dict(depth=2, width=64))
return FlowMatching(network_kwargs=dict(depth=2, width=32))


@pytest.fixture(params=["coupling_flow", "flow_matching"])
@pytest.fixture(params=["coupling_flow", "flow_matching"], scope="function")
def inference_network(request):
return request.getfixturevalue(request.param)


@pytest.fixture(params=["inference_network", "summary_network"], scope="function")
def network(request):
return request.getfixturevalue(request.param)


@pytest.fixture(autouse=True, scope="function")
def random_seed():
seed = 0
keras.utils.set_random_seed(seed)
return seed


@pytest.fixture(params=[None], scope="function")
def summary_network(request):
if request.param is None:
return None
return request.getfixturevalue(request.param)


@pytest.fixture(scope="session")
def two_moons(batch_size):
from bayesflow.datasets import OfflineDataset
from bayesflow.simulators import TwoMoonsSimulator

simulator = TwoMoonsSimulator()
samples = simulator.sample((4 * batch_size,))
return OfflineDataset(samples, batch_size=batch_size)

0 comments on commit 5a72c85

Please sign in to comment.