From 5c4103af73d759797022e091c2fff0f6ad3889ca Mon Sep 17 00:00:00 2001 From: Radev Date: Wed, 15 May 2024 09:31:12 -0400 Subject: [PATCH] Hopefully fixed test for base coupling flow --- tests/test_two_moons/conftest.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/tests/test_two_moons/conftest.py b/tests/test_two_moons/conftest.py index c04eb3047..ff0cf5796 100644 --- a/tests/test_two_moons/conftest.py +++ b/tests/test_two_moons/conftest.py @@ -44,26 +44,9 @@ def dataset(joint_distribution): @pytest.fixture() -def subnet_constructor(): - class Subnet(keras.Layer): - def __init__(self, out_features): - super().__init__() - self.network = keras.Sequential([ - keras.layers.Dense(32, activation="relu"), - keras.layers.Dense(out_features, kernel_initializer=keras.initializers.Zeros(), - bias_initializer=keras.initializers.Zeros()) - ]) - - def call(self, x): - return self.network(x) - - return Subnet - - -@pytest.fixture() -def inference_network(subnet_constructor): +def inference_network(): return bf.networks.CouplingFlow.all_in_one( - subnet_constructor=subnet_constructor, + subnet_builder="default", target_dim=2, num_layers=2, transform="affine",