diff --git a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py
index 92fd8a980..717536161 100644
--- a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py
+++ b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py
@@ -6,8 +6,6 @@ from bayesflow.experimental.simulation import Distribution, find_distribution
 from bayesflow.experimental.types import Shape, Tensor
 
 from .couplings import AllInOneCoupling
-from .subnets import find_subnet
-from .transforms import find_transform
 
 
 class CouplingFlow(keras.Sequential):
@@ -21,7 +19,7 @@ def all_in_one(
         cls,
         target_dim: int,
         num_layers: int,
-        subnet="default",
+        subnet_builder="default",
         transform="affine",
         permutation="fixed",
         act_norm=True,
@@ -30,13 +28,11 @@ def all_in_one(
     ) -> "CouplingFlow":
         """ Construct a uniform coupling flow, consisting of dual couplings with a single type of transform. """
-        subnet = find_subnet(subnet, transform, target_dim, **kwargs.pop("subnet_kwargs", {}))
-        transform = find_transform(transform)
         base_distribution = find_distribution(base_distribution, shape=(target_dim,))
 
         couplings = []
         for _ in range(num_layers):
-            layer = AllInOneCoupling(subnet, target_dim, transform, permutation, act_norm)
+            layer = AllInOneCoupling(subnet_builder, target_dim, transform, permutation, act_norm, **kwargs)
             couplings.append(layer)
 
         return cls(couplings, base_distribution)
diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/all_in_one_coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/all_in_one_coupling.py
index 6737d22c8..341e4515d 100644
--- a/bayesflow/experimental/networks/coupling_flow/couplings/all_in_one_coupling.py
+++ b/bayesflow/experimental/networks/coupling_flow/couplings/all_in_one_coupling.py
@@ -20,14 +20,15 @@ class AllInOneCoupling(keras.Layer):
 
     def __init__(
         self,
-        subnet: keras.Model | keras.layers.Layer,
+        subnet_builder: str,
         target_dim: int,
-        transform: Transform,
+        transform: str,
         permutation: str,
         act_norm: bool,
+        **kwargs
     ):
         super().__init__()
-        self.dual_coupling = DualCoupling(subnet, target_dim, transform)
+        self.dual_coupling = DualCoupling(subnet_builder, target_dim, transform, **kwargs)
 
         if permutation == "fixed":
             self.permutation = FixedPermutation.swap(target_dim)
diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/coupling.py
index 51d081438..a17f82af1 100644
--- a/bayesflow/experimental/networks/coupling_flow/couplings/coupling.py
+++ b/bayesflow/experimental/networks/coupling_flow/couplings/coupling.py
@@ -1,23 +1,30 @@
 import keras
+from keras import ops
 
-from ..transforms import Transform
+from ..transforms import Transform, find_transform
+from ..subnets import find_subnet
 
 
 class Coupling(keras.Layer):
-    """ Implements a single coupling layer, followed by a permutation.
""" + """ Implements a single coupling layer that transforms half of its input through a coupling transform.""" def __init__( self, - subnet: keras.Model | keras.layers.Layer, - target_dim: int, - transform: Transform, + subnet_builder: str, + half_dim: int, + transform: str, **kwargs ): super().__init__(**kwargs) - self.dim = target_dim - self.subnet = subnet - self.transform = transform + self.half_dim = half_dim + self.subnet = find_subnet( + subnet=subnet_builder, + transform=transform, + output_dim=half_dim, + **kwargs.pop('subnet_kwargs', {}) + ) + self.transform = find_transform(transform) def call(self, x, c=None, forward=True, **kwargs): if forward: @@ -26,26 +33,27 @@ def call(self, x, c=None, forward=True, **kwargs): def forward(self, x, c=None, **kwargs): - x1, x2 = keras.ops.split(x, 2, axis=-1) - z1 = x1 - parameters = self.get_parameters(x1, c, **kwargs) - z2, log_det = self.transform.forward(x2, parameters) - z = keras.ops.concatenate([z1, z2], axis=-1) + x1, x2 = x[..., :self.half_dim], x[..., self.half_dim:] + z2 = x2 + parameters = self.get_parameters(x2, c, **kwargs) + z1, log_det = self.transform.forward(x1, parameters) + z = ops.concatenate([z1, z2], axis=-1) return z, log_det def inverse(self, z, c=None): - z1, z2 = keras.ops.split(z, 2, axis=-1) - x1 = z1 - parameters = self.get_parameters(x1, c) - x2, log_det = self.transform.inverse(z2, parameters) - x = keras.ops.concatenate([x1, x2], axis=-1) + z1, z2 = z[..., :self.half_dim], z[..., self.half_dim:] + x2 = z2 + parameters = self.get_parameters(x2, c) + x1, log_det = self.transform.inverse(z1, parameters) + x = ops.concatenate([x1, x2], axis=-1) return x, log_det def get_parameters(self, x, c=None, **kwargs): if c is not None: - x = keras.ops.concatenate([x, c], axis=-1) + x = ops.concatenate([x, c], axis=-1) parameters = self.subnet(x, **kwargs) + parameters = self.transform.split_parameters(parameters) parameters = self.transform.constrain_parameters(parameters) return parameters diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py index a83f8ef27..06c332a09 100644 --- a/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py +++ b/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py @@ -5,27 +5,29 @@ from bayesflow.experimental.types import Tensor from .coupling import Coupling -from ..transforms import Transform class DualCoupling(keras.Layer): def __init__( self, - subnet: keras.Model | keras.layers.Layer, + subnet_builder: str, target_dim: int, - transform: Transform + transform: str, + **kwargs ): super().__init__() self.coupling1 = Coupling( - subnet=subnet, - target_dim=math.floor(target_dim / 2), + subnet_builder=subnet_builder, + half_dim=math.floor(target_dim / 2), transform=transform, + **kwargs ) self.coupling2 = Coupling( - subnet=subnet, - target_dim=math.ceil(target_dim / 2), + subnet_builder=subnet_builder, + half_dim=math.ceil(target_dim / 2), transform=transform, + **kwargs ) def call(self, x: Tensor, c=None, forward=True, **kwargs) -> (Tensor, Tensor): diff --git a/bayesflow/experimental/networks/coupling_flow/subnets/__init__.py b/bayesflow/experimental/networks/coupling_flow/subnets/__init__.py index d472a26d4..bcb4a2929 100644 --- a/bayesflow/experimental/networks/coupling_flow/subnets/__init__.py +++ b/bayesflow/experimental/networks/coupling_flow/subnets/__init__.py @@ -7,7 +7,7 @@ from ...resnet.residual_block import ConditionalResidualBlock 
-def find_subnet(subnet: str | Callable, transform: str, target_dim: int, **kwargs):
+def find_subnet(subnet: str | Callable, transform: str, output_dim: int, **kwargs):
 
     match subnet:
         case str() as name:
@@ -25,7 +25,7 @@ def find_subnet(subnet: str | Callable, transform: str, output_dim: int, **kwarg
         case str() as name:
             match name.lower():
                 case "affine":
-                    output_dim = target_dim * 2
+                    output_dim = output_dim * 2
                 case other:
                     raise NotImplementedError(f"Unsupported transform name: '{other}'.")
         case other:
diff --git a/bayesflow/experimental/networks/resnet/hidden_block.py b/bayesflow/experimental/networks/resnet/hidden_block.py
index 86e95ea1f..c33bc94a3 100644
--- a/bayesflow/experimental/networks/resnet/hidden_block.py
+++ b/bayesflow/experimental/networks/resnet/hidden_block.py
@@ -2,7 +2,7 @@ import keras
 
 
-class ConfigurableHiddenBlock(keras.Layer):
+class ConfigurableHiddenBlock(keras.layers.Layer):
     def __init__(
         self,
         num_units,
@@ -16,15 +16,26 @@ def __init__(
         self.activation_fn = keras.activations.get(activation)
         self.residual = residual
+        self.spectral_norm = spectral_norm
         self.dense_with_dropout = keras.Sequential()
+
         if spectral_norm:
             self.dense_with_dropout.add(keras.layers.SpectralNormalization(keras.layers.Dense(num_units)))
         else:
             self.dense_with_dropout.add(keras.layers.Dense(num_units))
         self.dense_with_dropout.add(keras.layers.Dropout(dropout_rate))
 
-    def call(self, inputs, **kwargs):
-        x = self.dense_with_dropout(inputs, **kwargs)
+    def call(self, inputs, training=False):
+        x = self.dense_with_dropout(inputs, training=training)
         if self.residual:
             x = x + inputs
         return self.activation_fn(x)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "activation": keras.activations.serialize(self.activation_fn),
+            "residual": self.residual,
+            "spectral_norm": self.spectral_norm
+        })
+        return config
diff --git a/bayesflow/experimental/networks/resnet/residual_block.py b/bayesflow/experimental/networks/resnet/residual_block.py
index 5b8a6998c..cb7ade40f 100644
--- a/bayesflow/experimental/networks/resnet/residual_block.py
+++ b/bayesflow/experimental/networks/resnet/residual_block.py
@@ -4,7 +4,7 @@ from .hidden_block import ConfigurableHiddenBlock
 
 
-class ConditionalResidualBlock(keras.Layer):
+class ConditionalResidualBlock(keras.layers.Layer):
     """
     Implements a simple configurable MLP with optional residual connections and dropout.
 
@@ -22,7 +22,7 @@ def __init__(
         spectral_norm=False,
         dropout_rate=0.05,
         zero_output_init=True,
-        **kwargs,
+        **kwargs
     ):
         """
         Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
@@ -31,7 +31,7 @@ def __init__(
         -----------
         output_dim : int
             The output dimensionality, needs to be specified according to the model's function.
-        hidden_dim : int, optional, default: 512
+        hidden_dim : int, optional, default: 256
            The dimensionality of the hidden layers
         num_hidden : int, optional, default: 2
             The number of hidden layers (minimum: 1)
@@ -39,7 +39,7 @@ def __init__(
             The activation function of the dense layers
         residual : bool, optional, default: True
             Use residual connections in the MLP.
-        spectral_norm : bool, optional, default: True
+        spectral_norm : bool, optional, default: False
             Use spectral normalization for the network weights, which can make
             the learned function smoother and hence more robust to perturbations.
         dropout_rate : float, optional, default: 0.05
@@ -52,9 +52,11 @@ def __init__(
         super().__init__(**kwargs)
 
         self.dim = output_dim
-        self.model = keras.Sequential()
+        self.res_blocks = keras.Sequential(
+            [keras.layers.Dense(hidden_dim, activation=activation), keras.layers.Dropout(dropout_rate)]
+        )
         for _ in range(num_hidden):
-            self.model.add(
+            self.res_blocks.add(
                 ConfigurableHiddenBlock(
                     num_units=hidden_dim,
                     activation=activation,
@@ -69,7 +71,6 @@ def __init__(
             output_initializer = "glorot_uniform"
         self.output_layer = keras.layers.Dense(output_dim, kernel_initializer=output_initializer)
 
-    def call(self, inputs, **kwargs):
-
-        out = self.model(inputs, **kwargs)
+    def call(self, inputs, training=False):
+        out = self.res_blocks(inputs, training=training)
         return self.output_layer(out)
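A minimal usage sketch of the refactored builder API, for review context. The argument values and the subnet_kwargs keys below are illustrative assumptions, not taken from this diff:

    # Sketch: construct a coupling flow from string-based builders. Each
    # Coupling now resolves its own subnet via find_subnet(...) and its own
    # transform via find_transform(...), instead of receiving a pre-built
    # subnet instance.
    from bayesflow.experimental.networks.coupling_flow.coupling_flow import CouplingFlow

    flow = CouplingFlow.all_in_one(
        target_dim=4,              # dimensionality of the inference targets
        num_layers=6,              # number of AllInOneCoupling layers
        subnet_builder="default",  # resolved per coupling by find_subnet
        transform="affine",        # resolved per coupling by find_transform
        permutation="fixed",
        act_norm=True,
        # hypothetical keys, forwarded through **kwargs and popped by each
        # Coupling as subnet_kwargs before find_subnet is called
        subnet_kwargs={"hidden_dim": 256, "num_hidden": 2},
    )

One consequence worth flagging: because builders rather than instances are now passed down, every Coupling constructs a fresh subnet, whereas the old code built a single subnet in all_in_one and passed that same object to every coupling layer.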