Commit 525e7de

Working version of couplings with residual internal networks

stefanradev93 committed May 14, 2024
1 parent 7597e45 commit 525e7de

Showing 7 changed files with 68 additions and 49 deletions.
@@ -6,8 +6,6 @@
 from bayesflow.experimental.simulation import Distribution, find_distribution
 from bayesflow.experimental.types import Shape, Tensor
 from .couplings import AllInOneCoupling
-from .subnets import find_subnet
-from .transforms import find_transform


 class CouplingFlow(keras.Sequential):
@@ -21,7 +19,7 @@ def all_in_one(
         cls,
         target_dim: int,
         num_layers: int,
-        subnet="default",
+        subnet_builder="default",
         transform="affine",
         permutation="fixed",
         act_norm=True,
@@ -30,13 +28,11 @@ def all_in_one(
     ) -> "CouplingFlow":
         """ Construct a uniform coupling flow, consisting of dual couplings with a single type of transform. """

-        subnet = find_subnet(subnet, transform, target_dim, **kwargs.pop("subnet_kwargs", {}))
-        transform = find_transform(transform)
         base_distribution = find_distribution(base_distribution, shape=(target_dim,))

         couplings = []
         for _ in range(num_layers):
-            layer = AllInOneCoupling(subnet, target_dim, transform, permutation, act_norm)
+            layer = AllInOneCoupling(subnet_builder, target_dim, transform, permutation, act_norm, **kwargs)
             couplings.append(layer)

         return cls(couplings, base_distribution)
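With this change, `all_in_one` takes the name of a subnet builder rather than a pre-built subnet, and forwards any extra keyword arguments down to each coupling, where the subnet and transform are resolved. A minimal construction sketch under the new signature (the import path is an assumption, since the file header for this first hunk is not shown; argument values are illustrative):

    from bayesflow.experimental.networks.coupling_flow import CouplingFlow  # assumed path

    flow = CouplingFlow.all_in_one(
        target_dim=10,
        num_layers=6,
        subnet_builder="default",
        transform="affine",
        permutation="fixed",
        act_norm=True,
    )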
@@ -20,14 +20,15 @@ class AllInOneCoupling(keras.Layer):

     def __init__(
         self,
-        subnet: keras.Model | keras.layers.Layer,
+        subnet_builder: str,
         target_dim: int,
-        transform: Transform,
+        transform: str,
         permutation: str,
         act_norm: bool,
+        **kwargs
     ):
         super().__init__()
-        self.dual_coupling = DualCoupling(subnet, target_dim, transform)
+        self.dual_coupling = DualCoupling(subnet_builder, target_dim, transform, **kwargs)

         if permutation == "fixed":
             self.permutation = FixedPermutation.swap(target_dim)
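The `permutation="fixed"` branch wires in `FixedPermutation.swap(target_dim)`, which, going by its name, swaps the two halves of the vector between coupling blocks so successive layers mix different coordinates. A toy stand-in to illustrate the idea (not the library's implementation):

    import numpy as np

    def swap_halves(x):
        # exchange the two halves along the last axis
        half = x.shape[-1] // 2
        return np.concatenate([x[..., half:], x[..., :half]], axis=-1)

    x = np.arange(4.0)
    print(swap_halves(x))  # [2. 3. 0. 1.]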
46 changes: 27 additions & 19 deletions bayesflow/experimental/networks/coupling_flow/couplings/coupling.py
@@ -1,23 +1,30 @@

 import keras
+from keras import ops

-from ..transforms import Transform
+from ..transforms import Transform, find_transform
+from ..subnets import find_subnet


 class Coupling(keras.Layer):
-    """ Implements a single coupling layer, followed by a permutation. """
+    """ Implements a single coupling layer that transforms half of its input through a coupling transform. """
     def __init__(
         self,
-        subnet: keras.Model | keras.layers.Layer,
-        target_dim: int,
-        transform: Transform,
+        subnet_builder: str,
+        half_dim: int,
+        transform: str,
         **kwargs
     ):
         super().__init__(**kwargs)

-        self.dim = target_dim
-        self.subnet = subnet
-        self.transform = transform
+        self.half_dim = half_dim
+        self.subnet = find_subnet(
+            subnet=subnet_builder,
+            transform=transform,
+            output_dim=half_dim,
+            **kwargs.pop('subnet_kwargs', {})
+        )
+        self.transform = find_transform(transform)

     def call(self, x, c=None, forward=True, **kwargs):
         if forward:
@@ -26,26 +33,27 @@ def call(self, x, c=None, forward=True, **kwargs):

     def forward(self, x, c=None, **kwargs):

-        x1, x2 = keras.ops.split(x, 2, axis=-1)
-        z1 = x1
-        parameters = self.get_parameters(x1, c, **kwargs)
-        z2, log_det = self.transform.forward(x2, parameters)
-        z = keras.ops.concatenate([z1, z2], axis=-1)
+        x1, x2 = x[..., :self.half_dim], x[..., self.half_dim:]
+        z2 = x2
+        parameters = self.get_parameters(x2, c, **kwargs)
+        z1, log_det = self.transform.forward(x1, parameters)
+        z = ops.concatenate([z1, z2], axis=-1)
         return z, log_det

     def inverse(self, z, c=None):
-        z1, z2 = keras.ops.split(z, 2, axis=-1)
-        x1 = z1
-        parameters = self.get_parameters(x1, c)
-        x2, log_det = self.transform.inverse(z2, parameters)
-        x = keras.ops.concatenate([x1, x2], axis=-1)
+        z1, z2 = z[..., :self.half_dim], z[..., self.half_dim:]
+        x2 = z2
+        parameters = self.get_parameters(x2, c)
+        x1, log_det = self.transform.inverse(z1, parameters)
+        x = ops.concatenate([x1, x2], axis=-1)
         return x, log_det

     def get_parameters(self, x, c=None, **kwargs):
         if c is not None:
-            x = keras.ops.concatenate([x, c], axis=-1)
+            x = ops.concatenate([x, c], axis=-1)

         parameters = self.subnet(x, **kwargs)
         parameters = self.transform.split_parameters(parameters)
         parameters = self.transform.constrain_parameters(parameters)

         return parameters
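The rewritten forward/inverse pair makes the coupling structure explicit: the first `half_dim` entries are transformed with parameters predicted from the untouched second half, which is what guarantees exact invertibility. A self-contained NumPy toy with a hard-coded affine transform standing in for the learned subnet and the library's `Transform` (all names here are illustrative):

    import numpy as np

    half_dim = 2

    def toy_subnet(x2):
        # stand-in for the learned subnet: predict a log-scale and a shift for x1 from x2
        log_scale = np.tanh(x2 @ np.ones((half_dim, half_dim)))
        shift = x2 @ np.full((half_dim, half_dim), 0.5)
        return log_scale, shift

    def forward(x):
        x1, x2 = x[..., :half_dim], x[..., half_dim:]
        log_scale, shift = toy_subnet(x2)
        z1 = x1 * np.exp(log_scale) + shift   # transform x1, conditioned on x2
        return np.concatenate([z1, x2], axis=-1)

    def inverse(z):
        z1, z2 = z[..., :half_dim], z[..., half_dim:]
        log_scale, shift = toy_subnet(z2)     # z2 equals x2, so the parameters match exactly
        x1 = (z1 - shift) * np.exp(-log_scale)
        return np.concatenate([x1, z2], axis=-1)

    x = np.random.randn(4, 2 * half_dim)
    assert np.allclose(inverse(forward(x)), x)  # exact round trip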
@@ -5,27 +5,29 @@

 from bayesflow.experimental.types import Tensor
 from .coupling import Coupling
-from ..transforms import Transform


 class DualCoupling(keras.Layer):
     def __init__(
         self,
-        subnet: keras.Model | keras.layers.Layer,
+        subnet_builder: str,
         target_dim: int,
-        transform: Transform
+        transform: str,
+        **kwargs
     ):
         super().__init__()

         self.coupling1 = Coupling(
-            subnet=subnet,
-            target_dim=math.floor(target_dim / 2),
+            subnet_builder=subnet_builder,
+            half_dim=math.floor(target_dim / 2),
             transform=transform,
+            **kwargs
         )
         self.coupling2 = Coupling(
-            subnet=subnet,
-            target_dim=math.ceil(target_dim / 2),
+            subnet_builder=subnet_builder,
+            half_dim=math.ceil(target_dim / 2),
             transform=transform,
+            **kwargs
         )

     def call(self, x: Tensor, c=None, forward=True, **kwargs) -> (Tensor, Tensor):
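Splitting `target_dim` into floor and ceil halves lets `DualCoupling` cover odd dimensionalities without padding: the two couplings transform complementary halves whose sizes always sum to `target_dim`. For instance:

    import math

    target_dim = 5
    print(math.floor(target_dim / 2), math.ceil(target_dim / 2))  # 2 3 -> together they cover all 5 dimensions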
@@ -7,7 +7,7 @@
 from ...resnet.residual_block import ConditionalResidualBlock


-def find_subnet(subnet: str | Callable, transform: str, target_dim: int, **kwargs):
+def find_subnet(subnet: str | Callable, transform: str, output_dim: int, **kwargs):

     match subnet:
         case str() as name:
@@ -25,7 +25,7 @@ def find_subnet(subnet: str | Callable, transform: str, output_dim: int, **kwargs):
         case str() as name:
             match name.lower():
                 case "affine":
-                    output_dim = target_dim * 2
+                    output_dim = output_dim * 2
                 case other:
                     raise NotImplementedError(f"Unsupported transform name: '{other}'.")
         case other:
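The renamed `output_dim` argument now denotes the number of transformed dimensions, and the affine case doubles it because an affine coupling predicts two parameters per transformed entry: a scale and a shift. A sketch of the convention (the exact ordering is handled by the library's `Transform.split_parameters`; the split below is an assumption for illustration only):

    import numpy as np

    half_dim = 3
    raw = np.random.randn(half_dim * 2)            # subnet output for an affine transform
    scale, shift = raw[:half_dim], raw[half_dim:]  # assumed ordering, illustrative only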
17 changes: 14 additions & 3 deletions bayesflow/experimental/networks/resnet/hidden_block.py
@@ -2,7 +2,7 @@
 import keras


-class ConfigurableHiddenBlock(keras.Layer):
+class ConfigurableHiddenBlock(keras.layers.Layer):
     def __init__(
         self,
         num_units,
@@ -16,15 +16,26 @@ def __init__(

         self.activation_fn = keras.activations.get(activation)
         self.residual = residual
+        self.spectral_norm = spectral_norm
         self.dense_with_dropout = keras.Sequential()

         if spectral_norm:
             self.dense_with_dropout.add(keras.layers.SpectralNormalization(keras.layers.Dense(num_units)))
         else:
             self.dense_with_dropout.add(keras.layers.Dense(num_units))
         self.dense_with_dropout.add(keras.layers.Dropout(dropout_rate))

-    def call(self, inputs, **kwargs):
-        x = self.dense_with_dropout(inputs, **kwargs)
+    def call(self, inputs, training=False):
+        x = self.dense_with_dropout(inputs, training=training)
         if self.residual:
             x = x + inputs
         return self.activation_fn(x)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "activation": keras.activations.serialize(self.activation_fn),
+            "residual": self.residual,
+            "spectral_norm": self.spectral_norm
+        })
+        return config
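A quick smoke test of the block under the signature visible in this diff (part of `__init__` is collapsed above, so the keyword names beyond `num_units` are partly assumed). Note that the residual add requires the input width to equal `num_units`:

    import numpy as np
    from bayesflow.experimental.networks.resnet.hidden_block import ConfigurableHiddenBlock

    block = ConfigurableHiddenBlock(num_units=8, activation="gelu", residual=True)
    x = np.random.randn(4, 8).astype("float32")  # width must match num_units for the residual add
    y = block(x, training=False)
    print(y.shape)  # (4, 8)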
19 changes: 10 additions & 9 deletions bayesflow/experimental/networks/resnet/residual_block.py
@@ -4,7 +4,7 @@
 from .hidden_block import ConfigurableHiddenBlock


-class ConditionalResidualBlock(keras.Layer):
+class ConditionalResidualBlock(keras.layers.Layer):
     """
     Implements a simple configurable MLP with optional residual connections and dropout.
@@ -22,7 +22,7 @@ def __init__(
         spectral_norm=False,
         dropout_rate=0.05,
         zero_output_init=True,
-        **kwargs,
+        **kwargs
     ):
         """
         Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
@@ -31,15 +31,15 @@ def __init__(
         -----------
         output_dim       : int
             The output dimensionality, needs to be specified according to the model's function.
-        hidden_dim       : int, optional, default: 512
+        hidden_dim       : int, optional, default: 256
             The dimensionality of the hidden layers
         num_hidden       : int, optional, default: 2
             The number of hidden layers (minimum: 1)
         activation       : string, optional, default: 'gelu'
             The activation function of the dense layers
         residual         : bool, optional, default: True
             Use residual connections in the MLP.
-        spectral_norm    : bool, optional, default: True
+        spectral_norm    : bool, optional, default: False
             Use spectral normalization for the network weights, which can make
             the learned function smoother and hence more robust to perturbations.
         dropout_rate     : float, optional, default: 0.05
@@ -52,9 +52,11 @@ def __init__(
         super().__init__(**kwargs)

         self.dim = output_dim
-        self.model = keras.Sequential()
+        self.res_blocks = keras.Sequential(
+            [keras.layers.Dense(hidden_dim, activation=activation), keras.layers.Dropout(dropout_rate)]
+        )
         for _ in range(num_hidden):
-            self.model.add(
+            self.res_blocks.add(
                 ConfigurableHiddenBlock(
                     num_units=hidden_dim,
                     activation=activation,
@@ -69,7 +71,6 @@ def __init__(
             output_initializer = "glorot_uniform"
         self.output_layer = keras.layers.Dense(output_dim, kernel_initializer=output_initializer)

-    def call(self, inputs, **kwargs):
-
-        out = self.model(inputs, **kwargs)
+    def call(self, inputs, training=False):
+        out = self.res_blocks(inputs, training=training)
         return self.output_layer(out)
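A corresponding smoke test for the MLP wrapper, using the arguments documented in its docstring. With the default `zero_output_init=True`, the initial output is all zeros, presumably so a coupling flow built on top starts near the identity transform:

    import numpy as np
    from bayesflow.experimental.networks.resnet.residual_block import ConditionalResidualBlock

    mlp = ConditionalResidualBlock(output_dim=4, hidden_dim=256, num_hidden=2)
    x = np.random.randn(16, 10).astype("float32")
    print(mlp(x, training=False).shape)  # (16, 4)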
