From 8b80304f9e37cf73b4a9a12e6aa2e9e5fc2787ba Mon Sep 17 00:00:00 2001 From: lars Date: Tue, 2 Jul 2024 19:02:07 +0200 Subject: [PATCH] fix minor bugs --- bayesflow/simulators/sequential_simulator.py | 4 ++-- bayesflow/simulators/two_moons_simulator.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/simulators/sequential_simulator.py b/bayesflow/simulators/sequential_simulator.py index 6446be453..cb88d6220 100644 --- a/bayesflow/simulators/sequential_simulator.py +++ b/bayesflow/simulators/sequential_simulator.py @@ -1,7 +1,7 @@ from collections.abc import Sequence import keras -from .composite_simulator import Compose +from .composite_simulator import CompositeSimulator from .functional_simulator import FunctionalSimulator from .simulator import Simulator from ..types import Shape, Tensor @@ -9,7 +9,7 @@ class SequentialSimulator(Simulator): def __init__(self, sample_fns: Sequence[callable], *, convert_dtypes: str = None, **kwargs): - self.inner = Compose([FunctionalSimulator(fn, **kwargs) for fn in sample_fns]) + self.inner = CompositeSimulator([FunctionalSimulator(fn, **kwargs) for fn in sample_fns]) self.convert_dtypes = convert_dtypes def sample(self, batch_shape: Shape, **kwargs) -> dict[str, Tensor]: diff --git a/bayesflow/simulators/two_moons_simulator.py b/bayesflow/simulators/two_moons_simulator.py index 813de3077..41e450c55 100644 --- a/bayesflow/simulators/two_moons_simulator.py +++ b/bayesflow/simulators/two_moons_simulator.py @@ -8,7 +8,7 @@ class TwoMoonsSimulator(Simulator): """TODO: Docs""" - def sample(self, batch_shape: Shape) -> dict[str, Tensor]: + def sample(self, batch_shape: Shape, **kwargs) -> dict[str, Tensor]: r = keras.random.normal(batch_shape + (1,), 0.1, 0.01) alpha = keras.random.uniform(batch_shape + (1,), -0.5 * np.pi, 0.5 * np.pi)