Completed model
Finished model in cif.py.
Added conditional_gaussian.py for ConditionalGaussian helper class.
Removed cif.ipynb example.
Added moons_cif.ipynb example.
Chase-Grajeda committed Aug 6, 2024
1 parent a4fa71e commit ad11d4b
Showing 4 changed files with 396 additions and 962 deletions.
117 changes: 62 additions & 55 deletions bayesflow/networks/cif/cif.py
@@ -1,86 +1,93 @@
import keras
from keras.saving import register_keras_serializable
from ..inference_network import InferenceNetwork
from ..coupling_flow import CouplingFlow
+from .conditional_gaussian import ConditionalGaussian


@register_keras_serializable(package="bayesflow.networks")
class CIF(InferenceNetwork):
-    def __init__(self, **kwargs):
+    """Implements a continuously indexed flow (CIF) with a `CouplingFlow` bijection and
+    `ConditionalGaussian` distributions p and q. Helps eliminate the leaky sampling that
+    arises from the topological constraints of bijective normalizing flows. Built in
+    reference to [1].
+
+    [1] R. Cornish, A. Caterini, G. Deligiannidis, & A. Doucet (2021).
+    Relaxing Bijectivity Constraints with Continuously Indexed Normalising Flows.
+    arXiv:1909.13833.
+    """
+
+    def __init__(self, pq_depth=4, pq_width=128, pq_activation="tanh", **kwargs):
+        """Creates an instance of a `CIF` with configurable `ConditionalGaussian` distributions
+        p and q, each parameterized by an MLP.
+
+        Parameters:
+        -----------
+        pq_depth: int, optional, default: 4
+            The number of MLP hidden layers (minimum: 1)
+        pq_width: int, optional, default: 128
+            The dimensionality of the MLP hidden layers
+        pq_activation: str, optional, default: 'tanh'
+            The MLP activation function
+        """

        super().__init__(base_distribution="normal", **kwargs)
-        # Member variables w.r.t. the nux implementation
-        self.feature_net = CouplingFlow()  # no conditions
-        self.flow = CouplingFlow()  # bijective transformer
-        self.u_dist = self.base_distribution  # Gaussian prior
-        self.v_dist = CouplingFlow()  # conditioned flow / parameterized Gaussian
+        self.bijection = CouplingFlow()
+        self.p_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)
+        self.q_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)


-    def build(self, xz_shape, conditions_shape):
-        super().build(xz_shape)
-        self.feature_net.build(xz_shape)
-        self.flow.build(xz_shape, xz_shape)
-        self.v_dist.build(xz_shape, xz_shape)
+    def build(self, xz_shape, conditions_shape=None):
+        super().build(xz_shape)
+        self.bijection.build(xz_shape, conditions_shape=conditions_shape)
+        self.p_dist.build(xz_shape)
+        self.q_dist.build(xz_shape)


-    def call(self, xz, conditions, inverse=False, **kwargs):
+    def call(self, xz, conditions=None, inverse=False, **kwargs):
        if inverse:
-            return self._inverse(xz, conditions, **kwargs)
-        return self._forward(xz, conditions, **kwargs)
+            return self._inverse(xz, conditions=conditions, **kwargs)
+        return self._forward(xz, conditions=conditions, **kwargs)


-    def _forward(self, x, conditions, density=False, **kwargs):
-        # NOTE: conditions should be used...

-        # Sample u ~ q(u|phi_x)
-        phi_x = self.feature_net(x, conditions=None)
-        u, log_qu = self.v_dist(keras.ops.zeros_like(x), conditions=phi_x, inverse=True, density=True)
+    def _forward(self, x, conditions=None, density=False, **kwargs):
+        # Sample u ~ q(u | x)
+        u, log_qu = self.q_dist.sample(x, log_prob=True)

+        # Bijection and log Jacobian for x -> z
+        z, log_jac = self.bijection(x, conditions=conditions, density=True)
+        if log_jac.ndim > 1:
+            log_jac = keras.ops.sum(log_jac, axis=1)

-        # Compute z = f(x; phi_u) and p(x|u)
-        phi_u = self.feature_net(u, conditions=None)
-        z, log_px = self.flow(x, conditions=phi_u, inverse=False, density=True)
+        # Log prob of u under p, conditioned on z
+        log_pu = self.p_dist.log_prob(u, z)

-        # Compute p(u)
-        log_pu = self.base_distribution.log_prob(u)
+        # Prior log prob of z
+        log_prior = self.base_distribution.log_prob(z)
+        if log_prior.ndim > 1:
+            log_prior = keras.ops.sum(log_prior, axis=1)

-        # Log likelihood?
-        llc = log_px + log_pu - log_qu
+        # Per-sample ELBO (negated later to form the loss)
+        elbo = log_jac + log_pu + log_prior - log_qu

-        # NOTE - this can be moved up when I'm done tinkering
        if density:
-            return z, llc
+            return z, elbo
        return z


-    def _inverse(self, z, conditions, density=False, **kwargs):
-        # NOTE: conditions should be used...

-        # Sample u ~ p(u)
-        u = self.base_distribution.sample(keras.ops.shape(z)[:-1])
-        log_pu = self.base_distribution.log_prob(keras.ops.zeros_like(z))

-        # Compute inverse of f(z; u)
-        phi_u = self.feature_net(u)
-        x, log_px = self.flow(z, conditions=phi_u, inverse=True, density=True)

-        # Predict q(u|x)
-        phi_x = self.feature_net(x)
-        _, log_qu = self.v_dist(u, conditions=phi_x, inverse=False, density=True)

-        # Log likelihood?
-        llc = log_px + log_pu - log_qu

-        # NOTE: this can be moved up when I'm done tinkering
+    def _inverse(self, z, conditions=None, density=False, **kwargs):
+        # Sample u ~ p(u | z), then invert the bijection z -> x
+        u = self.p_dist.sample(z)
+        x = self.bijection(z, conditions=conditions, inverse=True)
        if density:
-            return x, llc
+            log_pu = self.p_dist.log_prob(u, x)
+            return x, log_pu
        return x


    def compute_metrics(self, data, stage="training"):
        base_metrics = super().compute_metrics(data, stage=stage)
        inference_variables = data["inference_variables"]
        inference_conditions = data.get("inference_conditions")

-        z, log_density = self(inference_variables, conditions=inference_conditions, inverse=False, density=True)
-        # Should loss be reduced this way..?
-        loss = -keras.ops.mean(log_density)
+        _, elbo = self(inference_variables, conditions=inference_conditions, inverse=False, density=True)
+        loss = -keras.ops.mean(elbo)
        return base_metrics | {"loss": loss}
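
For orientation: with z the output of the bijection, the quantity returned as elbo in _forward is log_jac + log p(u | z) + log p(z) - log q(u | x), which compute_metrics negates and averages into the training loss. Below is a rough usage sketch, not part of the commit; the import path, the dummy shapes, and the assumption that the underlying CouplingFlow and MLP components build from those shapes are mine.

# Rough sketch (not from this commit): exercising the new CIF API with dummy tensors.
# The import path and shapes are assumptions based on this commit's file layout.
import keras
from bayesflow.networks.cif.cif import CIF

cif = CIF(pq_depth=4, pq_width=128, pq_activation="tanh")
cif.build(xz_shape=(None, 2), conditions_shape=(None, 3))

x = keras.random.normal((16, 2))      # inference variables
cond = keras.random.normal((16, 3))   # optional conditions

z, elbo = cif(x, conditions=cond, density=True)   # forward: x -> z plus the per-sample ELBO
x_back = cif(z, conditions=cond, inverse=True)    # inverse: z -> x (sampling direction)
loss = -keras.ops.mean(elbo)                      # the reduction used in compute_metrics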


76 changes: 76 additions & 0 deletions bayesflow/networks/cif/conditional_gaussian.py
@@ -0,0 +1,76 @@
+import keras
+from keras.saving import register_keras_serializable
+import numpy as np
+from ..mlp import MLP
+from bayesflow.utils import keras_kwargs


+@register_keras_serializable(package="bayesflow.networks.cif")
+class ConditionalGaussian(keras.Layer):
+    """Implements a conditional Gaussian distribution with neural networks for the
+    means and standard deviations, respectively. Built in reference to [1].
+
+    [1] R. Cornish, A. Caterini, G. Deligiannidis, & A. Doucet (2021).
+    Relaxing Bijectivity Constraints with Continuously Indexed Normalising Flows.
+    arXiv:1909.13833.
+    """

+    def __init__(self, depth=4, width=128, activation="tanh", **kwargs):
+        """Creates an instance of a `ConditionalGaussian` with configurable `MLP`
+        networks for the means and standard deviations.
+
+        Parameters:
+        -----------
+        depth: int, optional, default: 4
+            The number of MLP hidden layers (minimum: 1)
+        width: int, optional, default: 128
+            The dimensionality of the MLP hidden layers
+        activation: str, optional, default: "tanh"
+            The MLP activation function
+        """
+
+        super().__init__(**keras_kwargs(kwargs))
+        self.means = MLP(depth=depth, width=width, activation=activation)
+        self.stds = MLP(depth=depth, width=width, activation=activation)
+        self.output_projector = keras.layers.Dense(None)  # units are set in build() from the data dimensionality


+    def build(self, input_shape):
+        self.means.build(input_shape)
+        self.stds.build(input_shape)
+        self.output_projector.units = input_shape[-1]


+    def _diagonal_gaussian_log_prob(self, conditions, means, stds):
+        flat_c = keras.layers.Flatten()(conditions)
+        flat_means = keras.layers.Flatten()(means)
+        flat_vars = keras.layers.Flatten()(stds) ** 2

+        dim = keras.ops.shape(flat_c)[1]

+        const_term = -0.5 * dim * np.log(2 * np.pi)
+        log_det_terms = -0.5 * keras.ops.sum(keras.ops.log(flat_vars), axis=1)
+        product_terms = -0.5 * keras.ops.sum((flat_c - flat_means) ** 2 / flat_vars, axis=1)

+        return const_term + log_det_terms + product_terms


+    def log_prob(self, x, conditions):
+        means = self.output_projector(self.means(conditions))
+        stds = keras.ops.exp(self.output_projector(self.stds(conditions)))
+        return self._diagonal_gaussian_log_prob(x, means, stds)


+    def sample(self, conditions, log_prob=False):
+        means = self.output_projector(self.means(conditions))
+        stds = keras.ops.exp(self.output_projector(self.stds(conditions)))

+        # Reparameterization trick: samples = stds * eps + means, eps ~ N(0, I)
+        samples = stds * keras.random.normal(keras.ops.shape(conditions)) + means

+        # Log probability of the drawn samples, if requested
+        if log_prob:
+            log_p = self._diagonal_gaussian_log_prob(samples, means, stds)
+            return samples, log_p

+        return samples
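
The helper can also be exercised on its own. sample draws via the reparameterization trick (samples = stds * eps + means with eps ~ N(0, I)), and log_prob evaluates the diagonal-Gaussian density -D/2 * log(2*pi) - 1/2 * sum(log var_d) - 1/2 * sum((x_d - mean_d)^2 / var_d), whose means and log-stds come from the two MLPs applied to the conditions. A rough standalone sketch follows; it is not part of the commit, and the import path and dimensions are assumptions.

# Rough sketch (not from this commit): using ConditionalGaussian directly.
import keras
from bayesflow.networks.cif.conditional_gaussian import ConditionalGaussian

cg = ConditionalGaussian(depth=2, width=64, activation="tanh")
cg.build((None, 2))                         # fixes the output projector to 2 units

cond = keras.random.normal((16, 2))         # conditions; samples share their dimensionality
u, log_q = cg.sample(cond, log_prob=True)   # u ~ q(. | cond) together with its log density
log_p = cg.log_prob(u, cond)                # density of arbitrary points under the same Gaussian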
