Skip to content

Commit b2781a5

Browse files
committed
Added basic unit tests
Added test_sequential_simulators to test_simulators/test_simulators.py. Removed extra whitespace in seqential_simulator.py
1 parent e829619 commit b2781a5

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

bayesflow/simulators/sequential_simulator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import keras
2-
32
from collections.abc import Sequence
4-
53
from bayesflow.types import Sampler, Shape, Tensor
64
from bayesflow.utils import batched_call, filter_kwargs
7-
85
from .simulator import Simulator
96

107

tests/test_simulators/test_simulators.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,18 @@ def test_sample_is_random(simulator, batch_size):
1111
expected = keras.ops.size(tensor)
1212
actual = np.size(np.unique(array))
1313
assert actual == expected
14+
15+
16+
def test_sequential_simulators(sequential_two_moons, batch_size):
17+
data = sequential_two_moons.sample((batch_size,))
18+
19+
# Test all keys are present
20+
result_keys = set([key for key in data.keys()])
21+
expected_keys = set(["r", "alpha", "theta", "x"])
22+
assert result_keys == expected_keys
23+
24+
# Test correct output shapes are returned
25+
assert data["r"].shape == (batch_size,)
26+
assert data["alpha"].shape == (batch_size,)
27+
assert data["theta"].shape == (batch_size, 2)
28+
assert data["x"].shape == (batch_size, 2)

0 commit comments

Comments
 (0)