diff --git a/bayesflow/approximators/base_approximator.py b/bayesflow/approximators/base_approximator.py index 8df0a67d1..352d359c1 100644 --- a/bayesflow/approximators/base_approximator.py +++ b/bayesflow/approximators/base_approximator.py @@ -27,7 +27,9 @@ def __init__( self.summary_network = summary_network self.configurator = configurator - def sample(self, num_samples: int = 1, data: dict[str, Tensor] = None, numpy: bool = False) -> dict[str, Tensor]: + def sample(self, batch_shape: Shape, data: dict[str, Tensor] = None, numpy: bool = False) -> dict[str, Tensor]: + num_datasets, num_samples = batch_shape + if data is None: data = {} else: @@ -37,10 +39,14 @@ def sample(self, num_samples: int = 1, data: dict[str, Tensor] = None, numpy: bo data["summary_variables"] = self.configurator.configure_summary_variables(data) data["summary_outputs"] = self.summary_network(data["summary_variables"]) - conditions = self.configurator.configure_inference_conditions(data) - conditions = expand_tile(conditions, axis=1, n=num_samples) + inference_conditions = self.configurator.configure_inference_conditions(data) + + # TODO: do not assume this is a tensor + # TODO: do not rely on ndim == 2 vs ndim == 3 (i.e., allow multiple feature dimensions for conditions) + if inference_conditions is not None and keras.ops.ndim(inference_conditions) == 2: + inference_conditions = expand_tile(inference_conditions, axis=1, n=num_samples) - samples = self.inference_network.sample(num_samples, conditions=conditions) + samples = self.inference_network.sample(batch_shape, conditions=inference_conditions) samples = self.configurator.deconfigure(samples) if self.summary_network is not None: diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index aa9640fee..85776f604 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -1,6 +1,6 @@ import keras -from bayesflow.types import Tensor +from bayesflow.types import Shape, Tensor from bayesflow.utils import find_distribution @@ -26,8 +26,8 @@ def _forward(self, x: Tensor, **kwargs) -> Tensor | tuple[Tensor, Tensor]: def _inverse(self, z: Tensor, **kwargs) -> Tensor | tuple[Tensor, Tensor]: raise NotImplementedError - def sample(self, num_samples: int, conditions: Tensor = None, **kwargs) -> Tensor: - samples = self.base_distribution.sample((num_samples,)) + def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> Tensor: + samples = self.base_distribution.sample(batch_shape) samples = self(samples, conditions=conditions, inverse=True, density=False, **kwargs) return samples