Skip to content

Commit

Permalink
add type hints and do some general clean-up (#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue authored Sep 24, 2024
1 parent 0b2663a commit 2d154b4
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 38 deletions.
56 changes: 35 additions & 21 deletions bayesflow/networks/cif/cif.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
import keras
from keras.saving import register_keras_serializable
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor

from ..inference_network import InferenceNetwork
from ..coupling_flow import CouplingFlow

from .conditional_gaussian import ConditionalGaussian


@register_keras_serializable(package="bayesflow.networks")
@serializable(package="bayesflow.networks")
class CIF(InferenceNetwork):
"""Implements a continuously indexed flow (CIF) with a `CouplingFlow`
bijection and `ConditionalGaussian` distributions p and q. Improves on
eliminating leaky sampling found topologically in normalizing flows.
Bulit in reference to [1].
Built in reference to [1].
[1] R. Cornish, A. Caterini, G. Deligiannidis, & A. Doucet (2021).
Relaxing Bijectivity Constraints with Continuously Indexed Normalising
Flows.
arXiv:1909.13833.
"""

def __init__(self, pq_depth=4, pq_width=128, pq_activation="tanh", **kwargs):
def __init__(self, pq_depth: int = 4, pq_width: int = 128, pq_activation: str = "swish", **kwargs):
"""Creates an instance of a `CIF` with configurable
`ConditionalGaussian` distributions p and q, each containing MLP
networks
Expand All @@ -38,22 +42,26 @@ def __init__(self, pq_depth=4, pq_width=128, pq_activation="tanh", **kwargs):
self.p_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)
self.q_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)

def build(self, xz_shape, conditions_shape=None):
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
super().build(xz_shape)
self.bijection.build(xz_shape, conditions_shape=conditions_shape)
self.p_dist.build(xz_shape)
self.q_dist.build(xz_shape)

def call(self, xz, conditions=None, inverse=False, **kwargs):
def call(
self, xz: Tensor, conditions: Tensor = None, inverse: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
if inverse:
return self._inverse(xz, conditions=conditions, **kwargs)
return self._forward(xz, conditions=conditions, **kwargs)

def _forward(self, x, conditions=None, density=False, **kwargs):
def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
# Sample u ~ q_u
u, log_qu = self.q_dist.sample(x, log_prob=True)

# Bijection and log jacobian x -> z
# Bijection and log Jacobian x -> z
z, log_jac = self.bijection(x, conditions=conditions, density=True)
if log_jac.ndim > 1:
log_jac = keras.ops.sum(log_jac, axis=1)
Expand All @@ -66,26 +74,32 @@ def _forward(self, x, conditions=None, density=False, **kwargs):
if log_prior.ndim > 1:
log_prior = keras.ops.sum(log_prior, axis=1)

# ELBO loss
# we cannot compute an exact analytical density
elbo = log_jac + log_pu + log_prior - log_qu

if density:
return z, elbo

return z

def _inverse(self, z, conditions=None, density=False, **kwargs):
# Inverse bijection z -> x
def _inverse(
self, z: Tensor, conditions: Tensor = None, density: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
if not density:
return self.bijection(z, conditions=conditions, inverse=True, density=False)

u = self.p_dist.sample(z)
x = self.bijection(z, conditions=conditions, inverse=True)
if density:
log_pu = self.p_dist.log_prob(u, x)
return x, log_pu
return x

def compute_metrics(self, data, stage="training"):
base_metrics = super().compute_metrics(data, stage=stage)
inference_variables = data["inference_variables"]
inference_conditions = data.get("inference_conditions")
_, elbo = self(inference_variables, conditions=inference_conditions, inverse=False, density=True)

log_pu = self.p_dist.log_prob(u, x)

return x, log_pu

def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)

elbo = self.log_prob(x, conditions=conditions)

loss = -keras.ops.mean(elbo)

return base_metrics | {"loss": loss}
36 changes: 22 additions & 14 deletions bayesflow/networks/cif/conditional_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from keras.saving import register_keras_serializable
import numpy as np
from ..mlp import MLP

from bayesflow.types import Shape, Tensor
from bayesflow.utils import keras_kwargs


Expand All @@ -16,7 +18,7 @@ class ConditionalGaussian(keras.Layer):
arXiv:1909.13833.
"""

def __init__(self, depth=4, width=128, activation="tanh", **kwargs):
def __init__(self, depth: int = 4, width: int = 128, activation: str = "swish", **kwargs):
"""Creates an instance of a `ConditionalGaussian` with configurable
`MLP` networks for the means and standard deviations.
Expand All @@ -26,7 +28,7 @@ def __init__(self, depth=4, width=128, activation="tanh", **kwargs):
The number of MLP hidden layers (minimum: 1)
width: int, optional, default: 128
The dimensionality of the MLP hidden layers
activation: str, optional, default: "tanh"
activation: str, optional, default: "swish"
The MLP activation function
"""

Expand All @@ -35,37 +37,43 @@ def __init__(self, depth=4, width=128, activation="tanh", **kwargs):
self.stds = MLP(depth=depth, width=width, activation=activation)
self.output_projector = keras.layers.Dense(None)

def build(self, input_shape):
def build(self, input_shape: Shape) -> None:
self.means.build(input_shape)
self.stds.build(input_shape)
self.output_projector.units = input_shape[-1]

def _diagonal_gaussian_log_prob(self, conditions, means, stds):
flat_c = keras.layers.Flatten()(conditions)
flat_means = keras.layers.Flatten()(means)
flat_vars = keras.layers.Flatten()(stds) ** 2
def _diagonal_gaussian_log_prob(self, conditions: Tensor, means: Tensor, stds: Tensor) -> Tensor:
batch_size = keras.ops.shape(conditions)[0]

if keras.ops.shape(means)[0] != batch_size or keras.ops.shape(stds)[0] != batch_size:
raise ValueError("Means and stds must have the same batch size as conditions.")

flat_conditions = keras.ops.reshape(conditions, (batch_size, -1))
flat_means = keras.ops.reshape(means, (batch_size, -1))
flat_stds = keras.ops.reshape(stds, (batch_size, -1))

flat_variances = flat_stds**2

dim = keras.ops.shape(flat_c)[1]
dim = keras.ops.shape(flat_conditions)[1]

const_term = -0.5 * dim * np.log(2 * np.pi)
log_det_terms = -0.5 * keras.ops.sum(keras.ops.log(flat_vars), axis=1)
product_terms = -0.5 * keras.ops.sum((flat_c - flat_means) ** 2 / flat_vars, axis=1)
log_det_terms = -0.5 * keras.ops.sum(keras.ops.log(flat_variances), axis=1)
product_terms = -0.5 * keras.ops.sum((flat_conditions - flat_means) ** 2 / flat_variances, axis=1)

return const_term + log_det_terms + product_terms

def log_prob(self, x, conditions):
def log_prob(self, x: Tensor, conditions: Tensor) -> Tensor:
means = self.output_projector(self.means(conditions))
stds = keras.ops.exp(self.output_projector(self.stds(conditions)))
return self._diagonal_gaussian_log_prob(x, means, stds)

def sample(self, conditions, log_prob=False):
def sample(self, conditions: Tensor, log_prob: bool = False) -> Tensor | tuple[Tensor, Tensor]:
means = self.output_projector(self.means(conditions))
stds = keras.ops.exp(self.output_projector(self.stds(conditions)))

# Reparameterize
# re-parametrize
samples = stds * keras.random.normal(keras.ops.shape(conditions)) + means

# Log probability
if log_prob:
log_p = self._diagonal_gaussian_log_prob(samples, means, stds)
return samples, log_p
Expand Down
12 changes: 9 additions & 3 deletions bayesflow/networks/inference_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,21 @@ def __init__(self, base_distribution: str = "normal", **kwargs):
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
self.base_distribution.build(xz_shape)

def call(self, xz: Tensor, inverse: bool = False, **kwargs) -> Tensor | tuple[Tensor, Tensor]:
def call(
self, xz: Tensor, conditions: Tensor = None, inverse: bool = False, density: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
if inverse:
return self._inverse(xz, **kwargs)
return self._forward(xz, **kwargs)

def _forward(self, x: Tensor, **kwargs) -> Tensor | tuple[Tensor, Tensor]:
def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
raise NotImplementedError

def _inverse(self, z: Tensor, **kwargs) -> Tensor | tuple[Tensor, Tensor]:
def _inverse(
self, z: Tensor, conditions: Tensor = None, density: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
raise NotImplementedError

def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> Tensor:
Expand Down

0 comments on commit 2d154b4

Please sign in to comment.