Skip to content

Commit

Permalink
fix minor bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jul 2, 2024
1 parent 8c67dc2 commit 8b80304
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions bayesflow/simulators/sequential_simulator.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from collections.abc import Sequence
import keras

from .composite_simulator import Compose
from .composite_simulator import CompositeSimulator
from .functional_simulator import FunctionalSimulator
from .simulator import Simulator
from ..types import Shape, Tensor


class SequentialSimulator(Simulator):
def __init__(self, sample_fns: Sequence[callable], *, convert_dtypes: str = None, **kwargs):
self.inner = Compose([FunctionalSimulator(fn, **kwargs) for fn in sample_fns])
self.inner = CompositeSimulator([FunctionalSimulator(fn, **kwargs) for fn in sample_fns])
self.convert_dtypes = convert_dtypes

def sample(self, batch_shape: Shape, **kwargs) -> dict[str, Tensor]:
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/simulators/two_moons_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class TwoMoonsSimulator(Simulator):
"""TODO: Docs"""

def sample(self, batch_shape: Shape) -> dict[str, Tensor]:
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, Tensor]:
r = keras.random.normal(batch_shape + (1,), 0.1, 0.01)
alpha = keras.random.uniform(batch_shape + (1,), -0.5 * np.pi, 0.5 * np.pi)

Expand Down

0 comments on commit 8b80304

Please sign in to comment.