Finished model in cif.py. Added conditional_gaussian.py for ConditionalGaussian helper class. Removed cif.ipynb example. Added moons_cif.ipynb example.
Commit ad11d4b (1 parent: a4fa71e)
Showing 4 changed files with 396 additions and 962 deletions.
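Not part of the commit, but for orientation: a minimal sketch of how the reworked `CIF` interface in the diff below might be exercised. The `bayesflow.networks` import path, the manual `build` call, and the unconditional toy data are assumptions inferred from the diff, not verified against the repository:

```python
import numpy as np
import keras
from bayesflow.networks import CIF  # assumed export path, inferred from the serialization package name

# Instantiate the CIF with the new constructor arguments from this commit
cif = CIF(pq_depth=2, pq_width=64, pq_activation="tanh")

# Toy unconditional 2-D data (placeholder for a real dataset such as the two moons in moons_cif.ipynb)
x = keras.ops.convert_to_tensor(np.random.normal(size=(256, 2)).astype("float32"))
cif.build((256, 2))

# Forward pass: x -> z plus the per-sample ELBO used as the training objective
z, elbo = cif(x, inverse=False, density=True)
loss = -keras.ops.mean(elbo)

# Inverse pass: map base-distribution samples back to data space
z_new = keras.ops.convert_to_tensor(np.random.normal(size=(256, 2)).astype("float32"))
x_new = cif(z_new, inverse=True)
```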
cif.py
@@ -1,86 +1,93 @@
 import keras
 from keras.saving import register_keras_serializable
 from ..inference_network import InferenceNetwork
 from ..coupling_flow import CouplingFlow
+from .conditional_gaussian import ConditionalGaussian


 @register_keras_serializable(package="bayesflow.networks")
 class CIF(InferenceNetwork):
-    def __init__(self, **kwargs):
+    """Implements a continuously indexed flow (CIF) with a `CouplingFlow` bijection and
+    `ConditionalGaussian` distributions p and q. Aims to eliminate the leaky sampling
+    caused by topological constraints in ordinary normalizing flows. Built in reference to [1].
+
+    [1] R. Cornish, A. Caterini, G. Deligiannidis, & A. Doucet (2021).
+    Relaxing Bijectivity Constraints with Continuously Indexed Normalising Flows.
+    arXiv:1909.13833.
+    """
+
+    def __init__(self, pq_depth=4, pq_width=128, pq_activation="tanh", **kwargs):
+        """Creates an instance of a `CIF` with configurable `ConditionalGaussian`
+        distributions p and q, each parameterized by an MLP.
+
+        Parameters:
+        -----------
+        pq_depth: int, optional, default: 4
+            The number of MLP hidden layers (minimum: 1)
+        pq_width: int, optional, default: 128
+            The dimensionality of the MLP hidden layers
+        pq_activation: str, optional, default: "tanh"
+            The MLP activation function
+        """
         super().__init__(base_distribution="normal", **kwargs)
-        # Member variables wrt to nux implementation
-        self.feature_net = CouplingFlow()  # no conditions
-        self.flow = CouplingFlow()  # bijective transformer
-        self.u_dist = self.base_distribution  # Gaussian prior
-        self.v_dist = CouplingFlow()  # conditioned flow / parameterized gaussian
+        self.bijection = CouplingFlow()
+        self.p_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)
+        self.q_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)

-    def build(self, xz_shape, conditions_shape):
+    def build(self, xz_shape, conditions_shape=None):
         super().build(xz_shape)
-        self.feature_net.build(xz_shape)
-        self.flow.build(xz_shape, xz_shape)
-        self.v_dist.build(xz_shape, xz_shape)
+        self.bijection.build(xz_shape, conditions_shape=conditions_shape)
+        self.p_dist.build(xz_shape)
+        self.q_dist.build(xz_shape)

-    def call(self, xz, conditions, inverse=False, **kwargs):
+    def call(self, xz, conditions=None, inverse=False, **kwargs):
         if inverse:
-            return self._inverse(xz, conditions, **kwargs)
-        return self._forward(xz, conditions, **kwargs)
+            return self._inverse(xz, conditions=conditions, **kwargs)
+        return self._forward(xz, conditions=conditions, **kwargs)

-    def _forward(self, x, conditions, density=False, **kwargs):
-        # NOTE: conditions should be used...
-
-        # Sample u ~ q(u|phi_x)
-        phi_x = self.feature_net(x, conditions=None)
-        u, log_qu = self.v_dist(keras.ops.zeros_like(x), conditions=phi_x, inverse=True, density=True)
-
-        # Compute z = f(x; phi_u) and p(x|u)
-        phi_u = self.feature_net(u, conditions=None)
-        z, log_px = self.flow(x, conditions=phi_u, inverse=False, density=True)
-
-        # Compute p(u)
-        log_pu = self.base_distribution.log_prob(u)
-
-        # Log likelihood?
-        llc = log_px + log_pu - log_qu
-
-        # NOTE - this can be moved up when I'm done tinkering
+    def _forward(self, x, conditions=None, density=False, **kwargs):
+        # Sample u ~ q(u | x) via the reparameterization trick
+        u, log_qu = self.q_dist.sample(x, log_prob=True)
+
+        # Bijection x -> z and its log Jacobian determinant
+        z, log_jac = self.bijection(x, conditions=conditions, density=True)
+        if log_jac.ndim > 1:
+            log_jac = keras.ops.sum(log_jac, axis=1)
+
+        # Log prob of u under p, conditioned on z
+        log_pu = self.p_dist.log_prob(u, z)
+
+        # Prior log prob of z under the base distribution
+        log_prior = self.base_distribution.log_prob(z)
+        if log_prior.ndim > 1:
+            log_prior = keras.ops.sum(log_prior, axis=1)
+
+        # Per-sample ELBO (negated and averaged in compute_metrics to form the loss)
+        elbo = log_jac + log_pu + log_prior - log_qu
+
         if density:
-            return z, llc
+            return z, elbo
         return z

-    def _inverse(self, z, conditions, density=False, **kwargs):
-        # NOTE: conditions should be used...
-
-        # Sample u ~ p(u)
-        u = self.base_distribution.sample(keras.ops.shape(z)[:-1])
-        log_pu = self.base_distribution.log_prob(keras.ops.zeros_like(z))
-
-        # Compute inverse of f(z; u)
-        phi_u = self.feature_net(u)
-        x, log_px = self.flow(z, conditions=phi_u, inverse=True, density=True)
-
-        # Predict q(u|x)
-        phi_x = self.feature_net(x)
-        _, log_qu = self.v_dist(u, conditions=phi_x, inverse=False, density=True)
-
-        # Log likelihood?
-        llc = log_px + log_pu - log_qu
-
-        # NOTE: this can be moved up when I'm done tinkering
+    def _inverse(self, z, conditions=None, density=False, **kwargs):
+        # Sample u ~ p(u | z), then invert the bijection z -> x
+        u = self.p_dist.sample(z)
+        x = self.bijection(z, conditions=conditions, inverse=True)
         if density:
-            return x, llc
+            log_pu = self.p_dist.log_prob(u, x)
+            return x, log_pu
         return x

     def compute_metrics(self, data, stage="training"):
         base_metrics = super().compute_metrics(data, stage=stage)
         inference_variables = data["inference_variables"]
         inference_conditions = data.get("inference_conditions")

-        z, log_density = self(inference_variables, conditions=inference_conditions, inverse=False, density=True)
-        # Should loss be reduced this way..?
-        loss = -keras.ops.mean(log_density)
+        _, elbo = self(inference_variables, conditions=inference_conditions, inverse=False, density=True)
+        loss = -keras.ops.mean(elbo)
         return base_metrics | {"loss": loss}
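For reference, the quantity assembled in the new `_forward` (and negated and averaged over the batch in `compute_metrics`) is a standard auxiliary-variable evidence lower bound. In the notation of the diff, with z = f(x) the coupling-flow bijection under the given conditions and u drawn from `q_dist`:

```latex
\mathrm{ELBO}(x)
  = \mathbb{E}_{u \sim q(\cdot \mid x)}\!\left[
      \log \pi(z)
      + \log\left|\det \tfrac{\partial f}{\partial x}\right|
      + \log p(u \mid z)
      - \log q(u \mid x)
    \right]
  \;\le\; \log p(x)
```

Here π is the standard-normal base distribution, p is `p_dist`, and q is `q_dist`. Since, as far as this diff shows, the bijection itself is not indexed by u, the slack in the bound is exactly the KL divergence between q(u | x) and p(u | z).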
conditional_gaussian.py (new file)
@@ -0,0 +1,76 @@
+import keras
+from keras.saving import register_keras_serializable
+import numpy as np
+from ..mlp import MLP
+from bayesflow.utils import keras_kwargs
+
+
+@register_keras_serializable(package="bayesflow.networks.cif")
+class ConditionalGaussian(keras.Layer):
+    """Implements a conditional Gaussian distribution with neural networks for the
+    means and standard deviations, respectively. Built in reference to [1].
+
+    [1] R. Cornish, A. Caterini, G. Deligiannidis, & A. Doucet (2021).
+    Relaxing Bijectivity Constraints with Continuously Indexed Normalising Flows.
+    arXiv:1909.13833.
+    """
+
+    def __init__(self, depth=4, width=128, activation="tanh", **kwargs):
+        """Creates an instance of a `ConditionalGaussian` with configurable `MLP`
+        networks for the means and standard deviations.
+
+        Parameters:
+        -----------
+        depth: int, optional, default: 4
+            The number of MLP hidden layers (minimum: 1)
+        width: int, optional, default: 128
+            The dimensionality of the MLP hidden layers
+        activation: str, optional, default: "tanh"
+            The MLP activation function
+        """
+        super().__init__(**keras_kwargs(kwargs))
+        self.means = MLP(depth=depth, width=width, activation=activation)
+        self.stds = MLP(depth=depth, width=width, activation=activation)
+        # Units are filled in during build(), once the target dimensionality is known
+        self.output_projector = keras.layers.Dense(None)
+
+    def build(self, input_shape):
+        self.means.build(input_shape)
+        self.stds.build(input_shape)
+        self.output_projector.units = input_shape[-1]
+
+    def _diagonal_gaussian_log_prob(self, conditions, means, stds):
+        flat_c = keras.layers.Flatten()(conditions)
+        flat_means = keras.layers.Flatten()(means)
+        flat_vars = keras.layers.Flatten()(stds) ** 2
+
+        dim = keras.ops.shape(flat_c)[1]
+
+        const_term = -0.5 * dim * np.log(2 * np.pi)
+        log_det_terms = -0.5 * keras.ops.sum(keras.ops.log(flat_vars), axis=1)
+        product_terms = -0.5 * keras.ops.sum((flat_c - flat_means) ** 2 / flat_vars, axis=1)
+
+        return const_term + log_det_terms + product_terms
+
+    def log_prob(self, x, conditions):
+        means = self.output_projector(self.means(conditions))
+        stds = keras.ops.exp(self.output_projector(self.stds(conditions)))
+        return self._diagonal_gaussian_log_prob(x, means, stds)
+
+    def sample(self, conditions, log_prob=False):
+        means = self.output_projector(self.means(conditions))
+        stds = keras.ops.exp(self.output_projector(self.stds(conditions)))
+
+        # Reparameterize
+        samples = stds * keras.random.normal(keras.ops.shape(conditions)) + means
+
+        # Log probability
+        if log_prob:
+            log_p = self._diagonal_gaussian_log_prob(samples, means, stds)
+            return samples, log_p
+
+        return samples
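Again not part of the commit: a short sketch of the `ConditionalGaussian` contract as `CIF` uses it. The module path in the import is an assumption; the rest mirrors the methods defined above, which implement the usual diagonal-Gaussian log density -d/2 log 2π - Σ log σ_i - ½ Σ (x_i - μ_i)² / σ_i².

```python
import numpy as np
import keras
from bayesflow.networks.cif.conditional_gaussian import ConditionalGaussian  # assumed module path

cond_gauss = ConditionalGaussian(depth=2, width=64, activation="tanh")

# Conditions are the tensors the Gaussian is parameterized on (x for q_dist, z for p_dist inside CIF)
conditions = keras.ops.convert_to_tensor(np.random.normal(size=(16, 2)).astype("float32"))
cond_gauss.build((16, 2))

# Reparameterized draw u ~ N(mu(conditions), diag(sigma(conditions)^2)) together with its log density
u, log_q = cond_gauss.sample(conditions, log_prob=True)

# Log density of arbitrary points under the same conditional parameters
log_p = cond_gauss.log_prob(u, conditions)
```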