
Commit 525e7de

Working version of couplings with residual internal networks
1 parent 7597e45 commit 525e7de

7 files changed: +68 −49 lines

bayesflow/experimental/networks/coupling_flow/coupling_flow.py

Lines changed: 2 additions & 6 deletions

@@ -6,8 +6,6 @@
 from bayesflow.experimental.simulation import Distribution, find_distribution
 from bayesflow.experimental.types import Shape, Tensor
 from .couplings import AllInOneCoupling
-from .subnets import find_subnet
-from .transforms import find_transform


 class CouplingFlow(keras.Sequential):
@@ -21,7 +19,7 @@ def all_in_one(
         cls,
         target_dim: int,
         num_layers: int,
-        subnet="default",
+        subnet_builder="default",
         transform="affine",
         permutation="fixed",
         act_norm=True,
@@ -30,13 +28,11 @@ def all_in_one(
     ) -> "CouplingFlow":
         """ Construct a uniform coupling flow, consisting of dual couplings with a single type of transform. """

-        subnet = find_subnet(subnet, transform, target_dim, **kwargs.pop("subnet_kwargs", {}))
-        transform = find_transform(transform)
         base_distribution = find_distribution(base_distribution, shape=(target_dim,))

         couplings = []
         for _ in range(num_layers):
-            layer = AllInOneCoupling(subnet, target_dim, transform, permutation, act_norm)
+            layer = AllInOneCoupling(subnet_builder, target_dim, transform, permutation, act_norm, **kwargs)
             couplings.append(layer)

         return cls(couplings, base_distribution)
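Note: all_in_one no longer resolves the subnet and transform itself; each Coupling now performs the lookup. A minimal construction sketch, assuming the "default" subnet builder and "affine" transform names shown in the diffs below:

    from bayesflow.experimental.networks.coupling_flow.coupling_flow import CouplingFlow

    flow = CouplingFlow.all_in_one(
        target_dim=4,        # dimensionality of the modeled quantity
        num_layers=6,        # number of AllInOneCoupling layers
        subnet_builder="default",
        transform="affine",
        permutation="fixed",
        act_norm=True,
    )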

bayesflow/experimental/networks/coupling_flow/couplings/all_in_one_coupling.py

Lines changed: 4 additions & 3 deletions

@@ -20,14 +20,15 @@ class AllInOneCoupling(keras.Layer):

     def __init__(
         self,
-        subnet: keras.Model | keras.layers.Layer,
+        subnet_builder: str,
         target_dim: int,
-        transform: Transform,
+        transform: str,
         permutation: str,
         act_norm: bool,
+        **kwargs
     ):
         super().__init__()
-        self.dual_coupling = DualCoupling(subnet, target_dim, transform)
+        self.dual_coupling = DualCoupling(subnet_builder, target_dim, transform, **kwargs)

         if permutation == "fixed":
             self.permutation = FixedPermutation.swap(target_dim)
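Note: the builder name and any extra keyword arguments are now forwarded verbatim to DualCoupling. A construction sketch under the new signature (argument order taken from the diff; per the surviving context line, permutation="fixed" selects FixedPermutation.swap):

    coupling = AllInOneCoupling(
        subnet_builder="default",
        target_dim=4,
        transform="affine",
        permutation="fixed",   # resolved to FixedPermutation.swap(target_dim)
        act_norm=True,
    )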
bayesflow/experimental/networks/coupling_flow/couplings/coupling.py

Lines changed: 27 additions & 19 deletions

@@ -1,23 +1,30 @@

 import keras
+from keras import ops

-from ..transforms import Transform
+from ..transforms import Transform, find_transform
+from ..subnets import find_subnet


 class Coupling(keras.Layer):
-    """ Implements a single coupling layer, followed by a permutation. """
+    """ Implements a single coupling layer that transforms half of its input through a coupling transform. """
     def __init__(
         self,
-        subnet: keras.Model | keras.layers.Layer,
-        target_dim: int,
-        transform: Transform,
+        subnet_builder: str,
+        half_dim: int,
+        transform: str,
         **kwargs
     ):
         super().__init__(**kwargs)

-        self.dim = target_dim
-        self.subnet = subnet
-        self.transform = transform
+        self.half_dim = half_dim
+        self.subnet = find_subnet(
+            subnet=subnet_builder,
+            transform=transform,
+            output_dim=half_dim,
+            **kwargs.pop("subnet_kwargs", {})
+        )
+        self.transform = find_transform(transform)

     def call(self, x, c=None, forward=True, **kwargs):
         if forward:
@@ -26,26 +33,27 @@ def call(self, x, c=None, forward=True, **kwargs):

     def forward(self, x, c=None, **kwargs):

-        x1, x2 = keras.ops.split(x, 2, axis=-1)
-        z1 = x1
-        parameters = self.get_parameters(x1, c, **kwargs)
-        z2, log_det = self.transform.forward(x2, parameters)
-        z = keras.ops.concatenate([z1, z2], axis=-1)
+        x1, x2 = x[..., :self.half_dim], x[..., self.half_dim:]
+        z2 = x2
+        parameters = self.get_parameters(x2, c, **kwargs)
+        z1, log_det = self.transform.forward(x1, parameters)
+        z = ops.concatenate([z1, z2], axis=-1)
         return z, log_det

     def inverse(self, z, c=None):
-        z1, z2 = keras.ops.split(z, 2, axis=-1)
-        x1 = z1
-        parameters = self.get_parameters(x1, c)
-        x2, log_det = self.transform.inverse(z2, parameters)
-        x = keras.ops.concatenate([x1, x2], axis=-1)
+        z1, z2 = z[..., :self.half_dim], z[..., self.half_dim:]
+        x2 = z2
+        parameters = self.get_parameters(x2, c)
+        x1, log_det = self.transform.inverse(z1, parameters)
+        x = ops.concatenate([x1, x2], axis=-1)
         return x, log_det

     def get_parameters(self, x, c=None, **kwargs):
         if c is not None:
-            x = keras.ops.concatenate([x, c], axis=-1)
+            x = ops.concatenate([x, c], axis=-1)

         parameters = self.subnet(x, **kwargs)
+        parameters = self.transform.split_parameters(parameters)
         parameters = self.transform.constrain_parameters(parameters)

         return parameters
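Note: slicing by half_dim replaces keras.ops.split(x, 2, axis=-1), so odd input widths now split cleanly, and the roles are flipped: the first half is transformed, conditioned on the untouched second half. (One caveat visible in the hunk: super().__init__(**kwargs) runs before "subnet_kwargs" is popped, so a subnet_kwargs entry forwarded down from all_in_one would also reach keras.Layer.__init__; popping it first would avoid that.) A self-contained numpy sketch of the forward/inverse pattern, with a fixed random matrix standing in for the subnet and a hand-rolled affine transform (illustration only, not the repository's Transform API):

    import numpy as np

    rng = np.random.default_rng(0)
    half_dim = 2
    x = rng.standard_normal((8, 5))                    # odd width splits as 2 + 3
    x1, x2 = x[..., :half_dim], x[..., half_dim:]      # same slicing as Coupling.forward
    W = rng.standard_normal((3, 2 * half_dim)) * 0.1   # stand-in conditioner network
    log_scale, shift = np.split(x2 @ W, 2, axis=-1)    # parameters depend only on x2
    z1 = x1 * np.exp(log_scale) + shift                # "transform.forward" on the first half
    log_det = np.sum(log_scale, axis=-1)               # log|det J| of the affine map
    x1_rec = (z1 - shift) * np.exp(-log_scale)         # "transform.inverse" recovers x1
    assert np.allclose(x1, x1_rec)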

bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py

Lines changed: 9 additions & 7 deletions

@@ -5,27 +5,29 @@

 from bayesflow.experimental.types import Tensor
 from .coupling import Coupling
-from ..transforms import Transform


 class DualCoupling(keras.Layer):
     def __init__(
         self,
-        subnet: keras.Model | keras.layers.Layer,
+        subnet_builder: str,
         target_dim: int,
-        transform: Transform
+        transform: str,
+        **kwargs
     ):
         super().__init__()

         self.coupling1 = Coupling(
-            subnet=subnet,
-            target_dim=math.floor(target_dim / 2),
+            subnet_builder=subnet_builder,
+            half_dim=math.floor(target_dim / 2),
             transform=transform,
+            **kwargs
         )
         self.coupling2 = Coupling(
-            subnet=subnet,
-            target_dim=math.ceil(target_dim / 2),
+            subnet_builder=subnet_builder,
+            half_dim=math.ceil(target_dim / 2),
             transform=transform,
+            **kwargs
         )

     def call(self, x: Tensor, c=None, forward=True, **kwargs) -> (Tensor, Tensor):
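Note: sizing the two couplings with floor and ceil handles odd target_dim: the halves tile the full dimension (2 and 3 for target_dim=5), so the stacked couplings can together transform every component; the swapping between the two passes lives in call, outside this hunk. For example:

    import math

    target_dim = 5
    half1 = math.floor(target_dim / 2)   # coupling1 transforms 2 dims, conditioned on the other 3
    half2 = math.ceil(target_dim / 2)    # coupling2 transforms 3 dims, conditioned on the other 2
    assert half1 + half2 == target_dim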

bayesflow/experimental/networks/coupling_flow/subnets/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -7,7 +7,7 @@

 from ...resnet.residual_block import ConditionalResidualBlock


-def find_subnet(subnet: str | Callable, transform: str, target_dim: int, **kwargs):
+def find_subnet(subnet: str | Callable, transform: str, output_dim: int, **kwargs):

     match subnet:
         case str() as name:
@@ -25,7 +25,7 @@ def find_subnet(subnet: str | Callable, transform: str, target_dim: int, **kwargs):
         case str() as name:
             match name.lower():
                 case "affine":
-                    output_dim = target_dim * 2
+                    output_dim = output_dim * 2
                 case other:
                     raise NotImplementedError(f"Unsupported transform name: '{other}'.")
         case other:
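Note: the factor of two is the affine transform's parameter count, one scale and one shift per transformed dimension; the raw subnet output is divided up later by split_parameters in Coupling.get_parameters. A hypothetical sketch of that contract, with illustrative names rather than the repository's actual transform class:

    from keras import ops

    class AffineTransformSketch:
        # Assumed contract: the subnet emits 2 * half_dim features, which are
        # split into a scale and a shift, then constrained before use.

        def split_parameters(self, parameters):
            scale, shift = ops.split(parameters, 2, axis=-1)
            return {"scale": scale, "shift": shift}

        def constrain_parameters(self, parameters):
            parameters["scale"] = ops.softplus(parameters["scale"])  # keep scales positive
            return parameters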

bayesflow/experimental/networks/resnet/hidden_block.py

Lines changed: 14 additions & 3 deletions

@@ -2,7 +2,7 @@
 import keras


-class ConfigurableHiddenBlock(keras.Layer):
+class ConfigurableHiddenBlock(keras.layers.Layer):
     def __init__(
         self,
         num_units,
@@ -16,15 +16,26 @@ def __init__(

         self.activation_fn = keras.activations.get(activation)
         self.residual = residual
+        self.spectral_norm = spectral_norm
         self.dense_with_dropout = keras.Sequential()
+
         if spectral_norm:
             self.dense_with_dropout.add(keras.layers.SpectralNormalization(keras.layers.Dense(num_units)))
         else:
             self.dense_with_dropout.add(keras.layers.Dense(num_units))
         self.dense_with_dropout.add(keras.layers.Dropout(dropout_rate))

-    def call(self, inputs, **kwargs):
-        x = self.dense_with_dropout(inputs, **kwargs)
+    def call(self, inputs, training=False):
+        x = self.dense_with_dropout(inputs, training=training)
         if self.residual:
             x = x + inputs
         return self.activation_fn(x)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "activation": keras.activations.serialize(self.activation_fn),
+            "residual": self.residual,
+            "spectral_norm": self.spectral_norm
+        })
+        return config
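Note: threading an explicit training flag through call makes the Dropout sublayer active only during training, and the residual addition x = x + inputs requires the input width to equal num_units. A minimal call sketch (constructor keywords assumed from the hunk above):

    import numpy as np
    from bayesflow.experimental.networks.resnet.hidden_block import ConfigurableHiddenBlock

    block = ConfigurableHiddenBlock(num_units=8, activation="gelu", residual=True,
                                    spectral_norm=False, dropout_rate=0.1)
    x = np.zeros((4, 8), dtype="float32")   # width must match num_units for the residual add
    y_train = block(x, training=True)       # dropout active
    y_eval = block(x, training=False)       # dropout disabled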

bayesflow/experimental/networks/resnet/residual_block.py

Lines changed: 10 additions & 9 deletions

@@ -4,7 +4,7 @@

 from .hidden_block import ConfigurableHiddenBlock


-class ConditionalResidualBlock(keras.Layer):
+class ConditionalResidualBlock(keras.layers.Layer):
     """
     Implements a simple configurable MLP with optional residual connections and dropout.

@@ -22,7 +22,7 @@ def __init__(
         spectral_norm=False,
         dropout_rate=0.05,
         zero_output_init=True,
-        **kwargs,
+        **kwargs
     ):
         """
         Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
@@ -31,15 +31,15 @@ def __init__(
         -----------
         output_dim : int
             The output dimensionality, needs to be specified according to the model's function.
-        hidden_dim : int, optional, default: 512
+        hidden_dim : int, optional, default: 256
             The dimensionality of the hidden layers
         num_hidden : int, optional, default: 2
             The number of hidden layers (minimum: 1)
         activation : string, optional, default: 'gelu'
             The activation function of the dense layers
         residual : bool, optional, default: True
             Use residual connections in the MLP.
-        spectral_norm : bool, optional, default: True
+        spectral_norm : bool, optional, default: False
             Use spectral normalization for the network weights, which can make
             the learned function smoother and hence more robust to perturbations.
         dropout_rate : float, optional, default: 0.05
@@ -52,9 +52,11 @@ def __init__(
         super().__init__(**kwargs)

         self.dim = output_dim
-        self.model = keras.Sequential()
+        self.res_blocks = keras.Sequential(
+            [keras.layers.Dense(hidden_dim, activation=activation), keras.layers.Dropout(dropout_rate)]
+        )
         for _ in range(num_hidden):
-            self.model.add(
+            self.res_blocks.add(
                 ConfigurableHiddenBlock(
                     num_units=hidden_dim,
                     activation=activation,
@@ -69,7 +71,6 @@ def __init__(
             output_initializer = "glorot_uniform"
         self.output_layer = keras.layers.Dense(output_dim, kernel_initializer=output_initializer)

-    def call(self, inputs, **kwargs):
-
-        out = self.model(inputs, **kwargs)
+    def call(self, inputs, training=False):
+        out = self.res_blocks(inputs, training=training)
         return self.output_layer(out)
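Note: the new leading Dense + Dropout pair projects inputs of arbitrary width to hidden_dim before the residual blocks, which keeps the x + inputs additions inside ConfigurableHiddenBlock shape-compatible regardless of the input dimension. A usage sketch (constructor keywords taken from the docstring above):

    import numpy as np
    from bayesflow.experimental.networks.resnet.residual_block import ConditionalResidualBlock

    net = ConditionalResidualBlock(output_dim=2, hidden_dim=64, num_hidden=2)
    out = net(np.zeros((4, 10), dtype="float32"), training=False)  # 10 -> 64 -> 64 -> 2
    assert tuple(out.shape) == (4, 2)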
