Skip to content

Commit 56f88b3

Browse files
committed
add todo for density test
1 parent 666bb89 commit 56f88b3

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/test_networks/test_inference_networks.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def test_build(inference_network, random_samples, random_conditions):
2525

2626
def test_variable_batch_size(inference_network, random_samples, random_conditions):
2727
# build with one batch size
28-
inference_network(random_samples, conditions=random_conditions)
28+
samples_shape = keras.ops.shape(random_samples)
29+
conditions_shape = keras.ops.shape(random_conditions) if random_conditions is not None else None
30+
inference_network.build(samples_shape, conditions_shape=conditions_shape)
2931

3032
# run with another batch size
3133
batch_sizes = np.random.choice(10, replace=False, size=3)
@@ -81,6 +83,7 @@ def test_cycle_consistency(inference_network, random_samples, random_conditions)
8183
assert allclose(forward_log_density, inverse_log_density, atol=1e-3, rtol=1e-3)
8284

8385

86+
# TODO: make this backend-agnostic
8487
@pytest.mark.torch
8588
def test_density_numerically(inference_network, random_samples, random_conditions):
8689
import torch

0 commit comments

Comments
 (0)