Skip to content

Commit

Permalink
hacky fix for approximator.sample
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jul 17, 2024
1 parent a14cfff commit b4afea0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
14 changes: 10 additions & 4 deletions bayesflow/approximators/base_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def __init__(
self.summary_network = summary_network
self.configurator = configurator

def sample(self, num_samples: int = 1, data: dict[str, Tensor] = None, numpy: bool = False) -> dict[str, Tensor]:
def sample(self, batch_shape: Shape, data: dict[str, Tensor] = None, numpy: bool = False) -> dict[str, Tensor]:
num_datasets, num_samples = batch_shape

if data is None:
data = {}
else:
Expand All @@ -37,10 +39,14 @@ def sample(self, num_samples: int = 1, data: dict[str, Tensor] = None, numpy: bo
data["summary_variables"] = self.configurator.configure_summary_variables(data)
data["summary_outputs"] = self.summary_network(data["summary_variables"])

conditions = self.configurator.configure_inference_conditions(data)
conditions = expand_tile(conditions, axis=1, n=num_samples)
inference_conditions = self.configurator.configure_inference_conditions(data)

# TODO: do not assume this is a tensor
# TODO: do not rely on ndim == 2 vs ndim == 3 (i.e., allow multiple feature dimensions for conditions)
if inference_conditions is not None and keras.ops.ndim(inference_conditions) == 2:
inference_conditions = expand_tile(inference_conditions, axis=1, n=num_samples)

samples = self.inference_network.sample(num_samples, conditions=conditions)
samples = self.inference_network.sample(batch_shape, conditions=inference_conditions)
samples = self.configurator.deconfigure(samples)

if self.summary_network is not None:
Expand Down
6 changes: 3 additions & 3 deletions bayesflow/networks/inference_network.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import keras

from bayesflow.types import Tensor
from bayesflow.types import Shape, Tensor
from bayesflow.utils import find_distribution


Expand All @@ -26,8 +26,8 @@ def _forward(self, x: Tensor, **kwargs) -> Tensor | tuple[Tensor, Tensor]:
def _inverse(self, z: Tensor, **kwargs) -> Tensor | tuple[Tensor, Tensor]:
raise NotImplementedError

def sample(self, num_samples: int, conditions: Tensor = None, **kwargs) -> Tensor:
samples = self.base_distribution.sample((num_samples,))
def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> Tensor:
samples = self.base_distribution.sample(batch_shape)
samples = self(samples, conditions=conditions, inverse=True, density=False, **kwargs)
return samples

Expand Down

0 comments on commit b4afea0

Please sign in to comment.