From f566250bc7e75f15a555a736342e6bf3350f1d5b Mon Sep 17 00:00:00 2001 From: Radev Date: Fri, 24 May 2024 12:42:50 -0400 Subject: [PATCH] Fix sampling and get rid of tensorflow_probability for default Gaussians --- .../networks/coupling_flow/coupling_flow.py | 4 ++- .../simulation/distributions/__init__.py | 15 ++-------- .../distributions/spherical_gaussian.py | 30 +++++++++++++++++++ 3 files changed, 35 insertions(+), 14 deletions(-) create mode 100644 bayesflow/experimental/simulation/distributions/spherical_gaussian.py diff --git a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py index e41924559..d65f3f3f2 100644 --- a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py @@ -167,7 +167,9 @@ def inverse(self, latents, conditions=None) -> (Tensor, Tensor): return targets, log_det - def sample(self, batch_shape: Shape, conditions=None) -> Tensor: + def sample(self, batch_shape: Shape | int, conditions=None) -> Tensor: + if type(batch_shape) is int: + batch_shape = (batch_shape, ) latents = self.base_distribution.sample(batch_shape) targets, _ = self.inverse(latents, conditions) diff --git a/bayesflow/experimental/simulation/distributions/__init__.py b/bayesflow/experimental/simulation/distributions/__init__.py index 3f31e2b36..66458ff89 100644 --- a/bayesflow/experimental/simulation/distributions/__init__.py +++ b/bayesflow/experimental/simulation/distributions/__init__.py @@ -1,8 +1,8 @@ -import keras from bayesflow.experimental.types import Distribution, Shape from .joint_distribution import JointDistribution +from .spherical_gaussian import SphericalGaussian def find_distribution(distribution: str | Distribution | type(Distribution), shape: Shape) -> Distribution: @@ -10,20 +10,9 @@ def find_distribution(distribution: str | Distribution | type(Distribution), sha return distribution if isinstance(distribution, type): return Distribution() - match distribution: case "normal": - match keras.backend.backend(): - case "jax" | "tensorflow": - import tensorflow as tf - import tensorflow_probability as tfp - distribution = tfp.distributions.Normal(tf.zeros(shape), tf.ones(shape)) - distribution = tfp.distributions.Independent(distribution, 1) - case "torch": - import torch - import torch.distributions as D - distribution = D.Normal(torch.zeros(shape), torch.ones(shape)) - distribution = D.Independent(distribution, 1) + distribution = SphericalGaussian(shape) case str() as unknown_distribution: raise ValueError(f"Distribution '{unknown_distribution}' is unknown or not yet supported by name.") case other: diff --git a/bayesflow/experimental/simulation/distributions/spherical_gaussian.py b/bayesflow/experimental/simulation/distributions/spherical_gaussian.py new file mode 100644 index 000000000..5fb897b56 --- /dev/null +++ b/bayesflow/experimental/simulation/distributions/spherical_gaussian.py @@ -0,0 +1,30 @@ + +import math + +import keras +from keras import ops + +from bayesflow.experimental.types import Shape, Distribution, Tensor + + +class SphericalGaussian(Distribution): + """Utility class for a backend-agnostic spherical Gaussian distribution. + + Note: + - ``log_unnormalized_pdf`` method is used as a loss function + - ``log_pdf`` is used for density computation + """ + def __init__(self, shape: Shape): + self.shape = shape + self.dim = int(self.shape[0]) + self._norm_const = 0.5 * self.dim * math.log(2.0 * math.pi) + + def sample(self, batch_shape: Shape): + return keras.random.normal(shape=batch_shape + self.shape, mean=0.0, stddev=1.0) + + def log_unnormalized_prob(self, tensor: Tensor): + return -0.5 * ops.sum(ops.square(tensor), axis=-1) + + def log_prob(self, tensor: Tensor): + log_unnorm_pdf = self.log_unnormalized_prob(tensor) + return log_unnorm_pdf - self._norm_const