Skip to content

Commit

Permalink
Hopefully fixed test for base coupling flow
Browse files Browse the repository at this point in the history
  • Loading branch information
Radev committed May 15, 2024
1 parent 0875002 commit 5c4103a
Showing 1 changed file with 2 additions and 19 deletions.
21 changes: 2 additions & 19 deletions tests/test_two_moons/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,9 @@ def dataset(joint_distribution):


@pytest.fixture()
def subnet_constructor():
class Subnet(keras.Layer):
def __init__(self, out_features):
super().__init__()
self.network = keras.Sequential([
keras.layers.Dense(32, activation="relu"),
keras.layers.Dense(out_features, kernel_initializer=keras.initializers.Zeros(),
bias_initializer=keras.initializers.Zeros())
])

def call(self, x):
return self.network(x)

return Subnet


@pytest.fixture()
def inference_network(subnet_constructor):
def inference_network():
return bf.networks.CouplingFlow.all_in_one(
subnet_constructor=subnet_constructor,
subnet_builder="default",
target_dim=2,
num_layers=2,
transform="affine",
Expand Down

0 comments on commit 5c4103a

Please sign in to comment.