Merge pull request #182 from Chase-Grajeda/cif-base-implementation
CIF Implementation
Showing 5 changed files with 424 additions and 0 deletions.
@@ -0,0 +1 @@
from .cif import CIF
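This one-line module simply re-exports the new class. Assuming it is the `__init__.py` of a `cif` subpackage under `bayesflow.networks` (consistent with the `register_keras_serializable` package names used below), downstream code would then import the network as:

# Hypothetical import, assuming the file above is bayesflow/networks/cif/__init__.py
from bayesflow.networks.cif import CIF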
@@ -0,0 +1,91 @@
import keras
from keras.saving import register_keras_serializable
from ..inference_network import InferenceNetwork
from ..coupling_flow import CouplingFlow
from .conditional_gaussian import ConditionalGaussian


@register_keras_serializable(package="bayesflow.networks")
class CIF(InferenceNetwork):
    """Implements a continuously indexed flow (CIF) with a `CouplingFlow`
    bijection and `ConditionalGaussian` distributions p and q. Improves on
    standard normalizing flows by relaxing their bijectivity constraint,
    mitigating the topological limitations that can lead to leaky sampling.
    Built in reference to [1].

    [1] R. Cornish, A. Caterini, G. Deligiannidis, & A. Doucet (2021).
    Relaxing Bijectivity Constraints with Continuously Indexed Normalising
    Flows. arXiv:1909.13833.
    """

    def __init__(self, pq_depth=4, pq_width=128, pq_activation="tanh", **kwargs):
        """Creates an instance of a `CIF` with configurable
        `ConditionalGaussian` distributions p and q, each containing MLP
        networks.

        Parameters
        ----------
        pq_depth: int, optional, default: 4
            The number of MLP hidden layers (minimum: 1)
        pq_width: int, optional, default: 128
            The dimensionality of the MLP hidden layers
        pq_activation: str, optional, default: 'tanh'
            The MLP activation function
        """

        super().__init__(base_distribution="normal", **kwargs)
        self.bijection = CouplingFlow()
        self.p_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)
        self.q_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)

    def build(self, xz_shape, conditions_shape=None):
        super().build(xz_shape)
        self.bijection.build(xz_shape, conditions_shape=conditions_shape)
        self.p_dist.build(xz_shape)
        self.q_dist.build(xz_shape)

    def call(self, xz, conditions=None, inverse=False, **kwargs):
        if inverse:
            return self._inverse(xz, conditions=conditions, **kwargs)
        return self._forward(xz, conditions=conditions, **kwargs)

    def _forward(self, x, conditions=None, density=False, **kwargs):
        # Sample u ~ q_u
        u, log_qu = self.q_dist.sample(x, log_prob=True)

        # Bijection and log jacobian x -> z
        z, log_jac = self.bijection(x, conditions=conditions, density=True)
        if log_jac.ndim > 1:
            log_jac = keras.ops.sum(log_jac, axis=1)

        # Log prob over p on u with conditions z
        log_pu = self.p_dist.log_prob(u, z)

        # Prior log prob
        log_prior = self.base_distribution.log_prob(z)
        if log_prior.ndim > 1:
            log_prior = keras.ops.sum(log_prior, axis=1)

        # ELBO loss
        elbo = log_jac + log_pu + log_prior - log_qu

        if density:
            return z, elbo
        return z

    def _inverse(self, z, conditions=None, density=False, **kwargs):
        # Inverse bijection z -> x
        u = self.p_dist.sample(z)
        x = self.bijection(z, conditions=conditions, inverse=True)
        if density:
            log_pu = self.p_dist.log_prob(u, x)
            return x, log_pu
        return x

    def compute_metrics(self, data, stage="training"):
        base_metrics = super().compute_metrics(data, stage=stage)
        inference_variables = data["inference_variables"]
        inference_conditions = data.get("inference_conditions")
        _, elbo = self(inference_variables, conditions=inference_conditions, inverse=False, density=True)
        loss = -keras.ops.mean(elbo)
        return base_metrics | {"loss": loss}
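Taken together, `_forward` assembles a single-sample variational lower bound (ELBO) on the log density of `x`: the change-of-variables term `log_jac`, the base-distribution term `log_prior`, and the auxiliary-variable terms `log_pu - log_qu`; `compute_metrics` then negates the batch mean of this bound to form the training loss. Below is a minimal usage sketch; the import path, shapes, and batch size are illustrative assumptions, and it relies only on the `__init__`, `build`, and `call` signatures shown above.

# Usage sketch (assumed import path and shapes; not part of this diff)
import keras
from bayesflow.networks.cif import CIF  # hypothetical import path

cif = CIF(pq_depth=2, pq_width=64)
cif.build((16, 4), conditions_shape=(16, 8))  # assumed (batch, dim) shapes

x = keras.random.normal((16, 4))  # stand-in inference variables
c = keras.random.normal((16, 8))  # stand-in conditions

z, elbo = cif(x, conditions=c, density=True)  # forward: x -> z plus per-sample ELBO
x_back = cif(z, conditions=c, inverse=True)   # inverse: z -> x
loss = -keras.ops.mean(elbo)                  # the quantity minimized in compute_metrics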
@@ -0,0 +1,73 @@
import keras
from keras.saving import register_keras_serializable
import numpy as np
from ..mlp import MLP
from bayesflow.utils import keras_kwargs


@register_keras_serializable(package="bayesflow.networks.cif")
class ConditionalGaussian(keras.Layer):
    """Implements a conditional Gaussian distribution with neural networks
    for the means and standard deviations, respectively. Built in reference
    to [1].

    [1] R. Cornish, A. Caterini, G. Deligiannidis, & A. Doucet (2021).
    Relaxing Bijectivity Constraints with Continuously Indexed Normalising
    Flows. arXiv:1909.13833.
    """

    def __init__(self, depth=4, width=128, activation="tanh", **kwargs):
        """Creates an instance of a `ConditionalGaussian` with configurable
        `MLP` networks for the means and standard deviations.

        Parameters
        ----------
        depth: int, optional, default: 4
            The number of MLP hidden layers (minimum: 1)
        width: int, optional, default: 128
            The dimensionality of the MLP hidden layers
        activation: str, optional, default: "tanh"
            The MLP activation function
        """

        super().__init__(**keras_kwargs(kwargs))
        self.means = MLP(depth=depth, width=width, activation=activation)
        self.stds = MLP(depth=depth, width=width, activation=activation)
        # Units are assigned later in build(), once the output dimensionality is known
        self.output_projector = keras.layers.Dense(None)

    def build(self, input_shape):
        self.means.build(input_shape)
        self.stds.build(input_shape)
        self.output_projector.units = input_shape[-1]

    def _diagonal_gaussian_log_prob(self, conditions, means, stds):
        flat_c = keras.layers.Flatten()(conditions)
        flat_means = keras.layers.Flatten()(means)
        flat_vars = keras.layers.Flatten()(stds) ** 2

        dim = keras.ops.shape(flat_c)[1]

        const_term = -0.5 * dim * np.log(2 * np.pi)
        log_det_terms = -0.5 * keras.ops.sum(keras.ops.log(flat_vars), axis=1)
        product_terms = -0.5 * keras.ops.sum((flat_c - flat_means) ** 2 / flat_vars, axis=1)

        return const_term + log_det_terms + product_terms

    def log_prob(self, x, conditions):
        means = self.output_projector(self.means(conditions))
        stds = keras.ops.exp(self.output_projector(self.stds(conditions)))
        return self._diagonal_gaussian_log_prob(x, means, stds)

    def sample(self, conditions, log_prob=False):
        means = self.output_projector(self.means(conditions))
        stds = keras.ops.exp(self.output_projector(self.stds(conditions)))

        # Reparameterize
        samples = stds * keras.random.normal(keras.ops.shape(conditions)) + means

        # Log probability
        if log_prob:
            log_p = self._diagonal_gaussian_log_prob(samples, means, stds)
            return samples, log_p

        return samples
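For reference, `_diagonal_gaussian_log_prob` evaluates the standard log-density of a factorized (diagonal-covariance) Gaussian. The following self-contained sketch checks that formula against `scipy.stats.norm`, using NumPy instead of the Keras ops above; all names and values are illustrative only.

# Sanity check of the diagonal Gaussian log-density formula (illustrative, not part of this diff)
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x = rng.normal(size=5)              # one 5-dimensional point
means = rng.normal(size=5)
stds = np.exp(rng.normal(size=5))
variances = stds**2

dim = x.shape[0]
const_term = -0.5 * dim * np.log(2 * np.pi)
log_det_term = -0.5 * np.sum(np.log(variances))
product_term = -0.5 * np.sum((x - means) ** 2 / variances)
manual_log_prob = const_term + log_det_term + product_term

# The same quantity as a sum of independent univariate normal log-densities
reference = norm.logpdf(x, loc=means, scale=stds).sum()
assert np.isclose(manual_log_prob, reference)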
Large diffs are not rendered by default.