diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 9911d6920..5c389b329 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -11,7 +11,7 @@ from bayesflow.adapters import Adapter from bayesflow.networks import InferenceNetwork, SummaryNetwork from bayesflow.types import Tensor -from bayesflow.utils import logging, split_arrays +from bayesflow.utils import filter_kwargs, logging, split_arrays from .approximator import Approximator @@ -141,7 +141,7 @@ def sample( ) -> dict[str, np.ndarray]: conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs) conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions) - conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions)} + conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions, **kwargs)} conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions) conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs) @@ -154,6 +154,7 @@ def _sample( num_samples: int, inference_conditions: Tensor = None, summary_variables: Tensor = None, + **kwargs, ) -> Tensor: if self.summary_network is None: if summary_variables is not None: @@ -162,7 +163,9 @@ def _sample( if summary_variables is None: raise ValueError("Summary variables are required when a summary network is present.") - summary_outputs = self.summary_network(summary_variables) + summary_outputs = self.summary_network( + summary_variables, **filter_kwargs(kwargs, self.summary_network.call) + ) if inference_conditions is None: inference_conditions = summary_outputs @@ -180,18 +183,26 @@ def _sample( else: batch_shape = (num_samples,) - return self.inference_network.sample(batch_shape, conditions=inference_conditions) + return self.inference_network.sample( + batch_shape, + conditions=inference_conditions, + **filter_kwargs(kwargs, self.inference_network.sample), + ) - def log_prob(self, data: dict[str, np.ndarray]) -> np.ndarray: - data = self.adapter(data, strict=False, stage="inference") + def log_prob(self, data: dict[str, np.ndarray], **kwargs) -> np.ndarray: + data = self.adapter(data, strict=False, stage="inference", **kwargs) data = keras.tree.map_structure(keras.ops.convert_to_tensor, data) - log_prob = self._log_prob(**data) + log_prob = self._log_prob(**data, **kwargs) log_prob = keras.ops.convert_to_numpy(log_prob) return log_prob def _log_prob( - self, inference_variables: Tensor, inference_conditions: Tensor = None, summary_variables: Tensor = None + self, + inference_variables: Tensor, + inference_conditions: Tensor = None, + summary_variables: Tensor = None, + **kwargs, ) -> Tensor: if self.summary_network is None: if summary_variables is not None: @@ -200,11 +211,17 @@ def _log_prob( if summary_variables is None: raise ValueError("Summary variables are required when a summary network is present.") - summary_outputs = self.summary_network(summary_variables) + summary_outputs = self.summary_network( + summary_variables, **filter_kwargs(kwargs, self.summary_network.call) + ) if inference_conditions is None: inference_conditions = summary_outputs else: inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1) - return self.inference_network.log_prob(inference_variables, conditions=inference_conditions) + return self.inference_network.log_prob( + inference_variables, + conditions=inference_conditions, + **filter_kwargs(kwargs, self.inference_network.log_prob), + )