Skip to content

Commit

Permalink
Added basic unit tests
Browse files Browse the repository at this point in the history
Added test_sequential_simulators to test_simulators/test_simulators.py.
Removed extra whitespace in seqential_simulator.py
  • Loading branch information
Chase-Grajeda committed Jul 1, 2024
1 parent e829619 commit b2781a5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
3 changes: 0 additions & 3 deletions bayesflow/simulators/sequential_simulator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import keras

from collections.abc import Sequence

from bayesflow.types import Sampler, Shape, Tensor
from bayesflow.utils import batched_call, filter_kwargs

from .simulator import Simulator


Expand Down
15 changes: 15 additions & 0 deletions tests/test_simulators/test_simulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,18 @@ def test_sample_is_random(simulator, batch_size):
expected = keras.ops.size(tensor)
actual = np.size(np.unique(array))
assert actual == expected


def test_sequential_simulators(sequential_two_moons, batch_size):
data = sequential_two_moons.sample((batch_size,))

# Test all keys are present
result_keys = set([key for key in data.keys()])
expected_keys = set(["r", "alpha", "theta", "x"])
assert result_keys == expected_keys

# Test correct output shapes are returned
assert data["r"].shape == (batch_size,)
assert data["alpha"].shape == (batch_size,)
assert data["theta"].shape == (batch_size, 2)
assert data["x"].shape == (batch_size, 2)

0 comments on commit b2781a5

Please sign in to comment.