Merge pull request #182 from Chase-Grajeda/cif-base-implementation
CIF Implementation
stefanradev93 authored Aug 31, 2024
2 parents 07964b3 + df94651 commit 80ed60a
Showing 5 changed files with 424 additions and 0 deletions.
1 change: 1 addition & 0 deletions bayesflow/networks/__init__.py
@@ -1,3 +1,4 @@
from .cif import CIF
from .coupling_flow import CouplingFlow
from .deep_set import DeepSet
from .flow_matching import FlowMatching
1 change: 1 addition & 0 deletions bayesflow/networks/cif/__init__.py
@@ -0,0 +1 @@
from .cif import CIF
91 changes: 91 additions & 0 deletions bayesflow/networks/cif/cif.py
@@ -0,0 +1,91 @@
import keras
from keras.saving import register_keras_serializable
from ..inference_network import InferenceNetwork
from ..coupling_flow import CouplingFlow
from .conditional_gaussian import ConditionalGaussian


@register_keras_serializable(package="bayesflow.networks")
class CIF(InferenceNetwork):
"""Implements a continuously indexed flow (CIF) with a `CouplingFlow`
bijection and `ConditionalGaussian` distributions p and q. Improves on
eliminating leaky sampling found topologically in normalizing flows.
Bulit in reference to [1].
[1] R. Cornish, A. Caterini, G. Deligiannidis, & A. Doucet (2021).
Relaxing Bijectivity Constraints with Continuously Indexed Normalising
Flows.
arXiv:1909.13833.
"""

def __init__(self, pq_depth=4, pq_width=128, pq_activation="tanh", **kwargs):
"""Creates an instance of a `CIF` with configurable
`ConditionalGaussian` distributions p and q, each containing MLP
networks
Parameters:
-----------
pq_depth: int, optional, default: 4
The number of MLP hidden layers (minimum: 1)
pq_width: int, optional, default: 128
The dimensionality of the MLP hidden layers
pq_activation: str, optional, default: 'tanh'
The MLP activation function
"""

super().__init__(base_distribution="normal", **kwargs)
self.bijection = CouplingFlow()
self.p_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)
self.q_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)

def build(self, xz_shape, conditions_shape=None):
super().build(xz_shape)
self.bijection.build(xz_shape, conditions_shape=conditions_shape)
self.p_dist.build(xz_shape)
self.q_dist.build(xz_shape)

def call(self, xz, conditions=None, inverse=False, **kwargs):
if inverse:
return self._inverse(xz, conditions=conditions, **kwargs)
return self._forward(xz, conditions=conditions, **kwargs)

def _forward(self, x, conditions=None, density=False, **kwargs):
# Sample u ~ q_u
u, log_qu = self.q_dist.sample(x, log_prob=True)

# Bijection and log jacobian x -> z
z, log_jac = self.bijection(x, conditions=conditions, density=True)
if log_jac.ndim > 1:
log_jac = keras.ops.sum(log_jac, axis=1)

# Log prob over p on u with conditions z
log_pu = self.p_dist.log_prob(u, z)

# Prior log prob
log_prior = self.base_distribution.log_prob(z)
if log_prior.ndim > 1:
log_prior = keras.ops.sum(log_prior, axis=1)

# ELBO loss
elbo = log_jac + log_pu + log_prior - log_qu

if density:
return z, elbo
return z

def _inverse(self, z, conditions=None, density=False, **kwargs):
# Inverse bijection z -> x
u = self.p_dist.sample(z)
x = self.bijection(z, conditions=conditions, inverse=True)
if density:
log_pu = self.p_dist.log_prob(u, x)
return x, log_pu
return x

def compute_metrics(self, data, stage="training"):
base_metrics = super().compute_metrics(data, stage=stage)
inference_variables = data["inference_variables"]
inference_conditions = data.get("inference_conditions")
_, elbo = self(inference_variables, conditions=inference_conditions, inverse=False, density=True)
loss = -keras.ops.mean(elbo)
return base_metrics | {"loss": loss}
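
For orientation, the sketch below shows how the forward and inverse passes above might be exercised end to end. It is a minimal example with assumed toy shapes (2-D inference variables, 3-D conditions, batch of 64); the `CIF` constructor, `build`, and `call` signatures are taken from the diff above, everything else is illustrative.

import keras
from bayesflow.networks import CIF

# Toy shapes, assumed for illustration only
batch_size, x_dim, cond_dim = 64, 2, 3

cif = CIF(pq_depth=4, pq_width=128, pq_activation="tanh")
cif.build((batch_size, x_dim), conditions_shape=(batch_size, cond_dim))

x = keras.random.normal((batch_size, x_dim))
conditions = keras.random.normal((batch_size, cond_dim))

# Forward pass x -> z; with density=True the per-sample ELBO is returned,
# i.e. log|det J| + log p(u | z) + log p(z) - log q(u | x)
z, elbo = cif(x, conditions=conditions, density=True)

# Training minimizes the negative mean ELBO (cf. compute_metrics above)
loss = -keras.ops.mean(elbo)

# Inverse pass z -> x draws a sample from the learned flow
x_sampled = cif(z, conditions=conditions, inverse=True)
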
73 changes: 73 additions & 0 deletions bayesflow/networks/cif/conditional_gaussian.py
@@ -0,0 +1,73 @@
import keras
from keras.saving import register_keras_serializable
import numpy as np
from ..mlp import MLP
from bayesflow.utils import keras_kwargs


@register_keras_serializable(package="bayesflow.networks.cif")
class ConditionalGaussian(keras.Layer):
"""Implements a conditional gaussian distribution with neural networks for
the means and standard deviations respectively. Bulit in reference to [1].
[1] R. Cornish, A. Caterini, G. Deligiannidis, & A. Doucet (2021).
Relaxing Bijectivity Constraints with Continuously Indexed Normalising
Flows.
arXiv:1909.13833.
"""

def __init__(self, depth=4, width=128, activation="tanh", **kwargs):
"""Creates an instance of a `ConditionalGaussian` with configurable
`MLP` networks for the means and standard deviations.

Parameters
----------
depth: int, optional, default: 4
The number of MLP hidden layers (minimum: 1)
width: int, optional, default: 128
The dimensionality of the MLP hidden layers
activation: str, optional, default: "tanh"
The MLP activation function
"""

super().__init__(**keras_kwargs(kwargs))
self.means = MLP(depth=depth, width=width, activation=activation)
self.stds = MLP(depth=depth, width=width, activation=activation)
self.output_projector = keras.layers.Dense(None)

def build(self, input_shape):
self.means.build(input_shape)
self.stds.build(input_shape)
self.output_projector.units = input_shape[-1]

def _diagonal_gaussian_log_prob(self, conditions, means, stds):
flat_c = keras.layers.Flatten()(conditions)
flat_means = keras.layers.Flatten()(means)
flat_vars = keras.layers.Flatten()(stds) ** 2

dim = keras.ops.shape(flat_c)[1]

const_term = -0.5 * dim * np.log(2 * np.pi)
log_det_terms = -0.5 * keras.ops.sum(keras.ops.log(flat_vars), axis=1)
product_terms = -0.5 * keras.ops.sum((flat_c - flat_means) ** 2 / flat_vars, axis=1)

return const_term + log_det_terms + product_terms

def log_prob(self, x, conditions):
means = self.output_projector(self.means(conditions))
stds = keras.ops.exp(self.output_projector(self.stds(conditions)))
return self._diagonal_gaussian_log_prob(x, means, stds)

def sample(self, conditions, log_prob=False):
means = self.output_projector(self.means(conditions))
stds = keras.ops.exp(self.output_projector(self.stds(conditions)))

# Reparameterize
samples = stds * keras.random.normal(keras.ops.shape(conditions)) + means

# Log probability
if log_prob:
log_p = self._diagonal_gaussian_log_prob(samples, means, stds)
return samples, log_p

return samples
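
The `_diagonal_gaussian_log_prob` helper above evaluates the standard closed-form log-density of a Gaussian with diagonal covariance. As a sanity check, a standalone NumPy version (hypothetical, not part of the module) can be compared against `scipy.stats.norm` on toy values:

import numpy as np
from scipy.stats import norm

def diagonal_gaussian_log_prob(x, means, stds):
    """Reference: log N(x; means, diag(stds**2)), summed over the last axis."""
    variances = stds ** 2
    dim = x.shape[-1]
    const_term = -0.5 * dim * np.log(2.0 * np.pi)
    log_det_term = -0.5 * np.sum(np.log(variances), axis=-1)
    quad_term = -0.5 * np.sum((x - means) ** 2 / variances, axis=-1)
    return const_term + log_det_term + quad_term

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
means = rng.normal(size=(4, 3))
stds = np.exp(rng.normal(size=(4, 3)))

# With independent coordinates, the joint log-density is the sum of 1-D normals
reference = norm.logpdf(x, loc=means, scale=stds).sum(axis=-1)
assert np.allclose(diagonal_gaussian_log_prob(x, means, stds), reference)
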
258 changes: 258 additions & 0 deletions examples/moons_cif.ipynb

Large diffs are not rendered by default.
