Skip to content

Commit

Permalink
globalize more fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jul 15, 2024
1 parent 970679e commit f3e093a
Showing 1 changed file with 36 additions and 3 deletions.
39 changes: 36 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,33 @@ def batch_size(request):
return request.param


@pytest.fixture(params=[None, 2, 3], scope="session")
def conditions_size(request):
return request.param


@pytest.fixture(scope="function")
def coupling_flow(request):
def coupling_flow():
from bayesflow.networks import CouplingFlow

return CouplingFlow(depth=2, subnet_kwargs=dict(depth=2, width=32))
return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(depth=2, width=32))


@pytest.fixture(params=["two_moons"], scope="session")
def dataset(request):
return request.getfixturevalue(request.param)


@pytest.fixture(params=[2, 3], scope="session")
def feature_size(request):
return request.param


@pytest.fixture(scope="function")
def flow_matching():
from bayesflow.networks import FlowMatching

return FlowMatching(subnet="resnet", network_kwargs=dict(depth=2, width=32))
return FlowMatching(subnet="mlp", subnet_kwargs=dict(depth=2, width=32))


@pytest.fixture(params=["coupling_flow", "flow_matching"], scope="function")
Expand All @@ -53,13 +63,36 @@ def network(request):
return request.getfixturevalue(request.param)


@pytest.fixture(scope="session")
def random_conditions(batch_size, conditions_size):
if conditions_size is None:
return None

return keras.random.normal((batch_size, conditions_size))


@pytest.fixture(scope="session")
def random_samples(batch_size, feature_size):
return keras.random.normal((batch_size, feature_size))


@pytest.fixture(scope="function", autouse=True)
def random_seed():
seed = 0
keras.utils.set_random_seed(seed)
return seed


@pytest.fixture(scope="session")
def random_set(batch_size, set_size, feature_size):
return keras.random.normal((batch_size, set_size, feature_size))


@pytest.fixture(params=[2, 3], scope="session")
def set_size(request):
return request.param


@pytest.fixture(params=["two_moons"], scope="session")
def simulator(request):
return request.getfixturevalue(request.param)
Expand Down

0 comments on commit f3e093a

Please sign in to comment.