Flexible Default Settings Across Inference Networks (#296)
* Add default configs

* Correct configs and nice defaults

* Use a small network and fixed-step integration for flow matching tests

---------

Co-authored-by: larskue <[email protected]>
stefanradev93 and LarsKue authored Feb 14, 2025
1 parent a2f145d commit 22521c7
Showing 7 changed files with 408 additions and 230 deletions.
20 changes: 18 additions & 2 deletions bayesflow/networks/consistency_models/consistency_model.py
@@ -29,6 +29,15 @@ class ConsistencyModel(InferenceNetwork):
Discussion: https://openreview.net/forum?id=WNzy9bRDvG
"""

MLP_DEFAULT_CONFIG = {
"widths": (256, 256, 256, 256, 256),
"activation": "mish",
"kernel_initializer": "he_normal",
"residual": True,
"dropout": 0.05,
"spectral_normalization": False,
}

def __init__(
self,
total_steps: int | float,
@@ -65,12 +74,18 @@ def __init__(
**kwargs : dict, optional, default: {}
Additional keyword arguments
"""
# Normal is the only supported base distribution for CMs
super().__init__(base_distribution="normal", **keras_kwargs(kwargs))

self.total_steps = float(total_steps)

self.student = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
if subnet == "mlp":
subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG.copy()
subnet_kwargs.update(kwargs.get("subnet_kwargs", {}))
else:
subnet_kwargs = kwargs.get("subnet_kwargs", {})

self.student = find_network(subnet, **subnet_kwargs)

self.student_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")

self.sigma2 = ops.convert_to_tensor(sigma2)
@@ -82,6 +97,7 @@ def __init__(

self.s0 = float(s0)
self.s1 = float(s1)

# create variable that works with JIT compilation
self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int")
self.current_step.assign(0)
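The merge above applies the class-level defaults only when the built-in "mlp" subnet is requested, and user-supplied subnet_kwargs win key by key. A standalone sketch of that pattern (plain Python, not bayesflow code; names mirror the diff):

MLP_DEFAULT_CONFIG = {
    "widths": (256, 256, 256, 256, 256),
    "activation": "mish",
    "kernel_initializer": "he_normal",
    "residual": True,
    "dropout": 0.05,
    "spectral_normalization": False,
}

def resolve_subnet_kwargs(subnet, user_kwargs):
    # Overlay user kwargs onto the defaults, but only for the built-in MLP;
    # custom subnet classes receive the user kwargs untouched.
    if subnet == "mlp":
        merged = MLP_DEFAULT_CONFIG.copy()
        merged.update(user_kwargs)
        return merged
    return user_kwargs

# Overriding one key keeps the remaining defaults intact:
print(resolve_subnet_kwargs("mlp", {"dropout": 0.0})["widths"])  # (256, 256, 256, 256, 256)

The .copy() before .update() matters: updating MLP_DEFAULT_CONFIG in place would silently change the defaults for every later instance.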
17 changes: 16 additions & 1 deletion bayesflow/networks/coupling_flow/couplings/single_coupling.py
@@ -15,10 +15,25 @@ class SingleCoupling(InvertibleLayer):
Subnet output tensors are linearly mapped to the correct dimension.
"""

MLP_DEFAULT_CONFIG = {
"widths": (128, 128),
"activation": "hard_silu",
"kernel_initializer": "glorot_uniform",
"residual": False,
"dropout": 0.05,
"spectral_normalization": False,
}

def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwargs):
super().__init__(**keras_kwargs(kwargs))

self.network = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
if subnet == "mlp":
subnet_kwargs = SingleCoupling.MLP_DEFAULT_CONFIG.copy()
subnet_kwargs.update(kwargs.get("subnet_kwargs", {}))
else:
subnet_kwargs = kwargs.get("subnet_kwargs", {})

self.network = find_network(subnet, **subnet_kwargs)
self.transform = find_transform(transform, **kwargs.get("transform_kwargs", {}))

output_projector_kwargs = kwargs.get("output_projector_kwargs", {})
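Hypothetical usage of the coupling defaults (the import path is assumed from the file location; verify against the installed package). Couplings get a deliberately lighter default than the five-layer mish MLPs above: two layers of 128 units, hard_silu, no residuals.

from bayesflow.networks.coupling_flow.couplings.single_coupling import SingleCoupling  # assumed path

# Only dropout deviates from SingleCoupling.MLP_DEFAULT_CONFIG; widths,
# activation, and initializer are filled in by the merge shown above.
coupling = SingleCoupling(subnet="mlp", transform="affine", subnet_kwargs={"dropout": 0.0})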
62 changes: 37 additions & 25 deletions bayesflow/networks/flow_matching/flow_matching.py
@@ -27,6 +27,31 @@ class FlowMatching(InferenceNetwork):
[3] Optimal Transport Flow Matching: arXiv:2302.00482
"""

MLP_DEFAULT_CONFIG = {
"widths": (256, 256, 256, 256, 256),
"activation": "mish",
"kernel_initializer": "he_normal",
"residual": True,
"dropout": 0.05,
"spectral_normalization": False,
}

OPTIMAL_TRANSPORT_DEFAULT_CONFIG = {
"method": "sinkhorn",
"cost": "euclidean",
"regularization": 0.1,
"max_steps": 100,
"tolerance": 1e-4,
}

INTEGRATE_DEFAULT_CONFIG = {
"method": "rk45",
"steps": "adaptive",
"tolerance": 1e-3,
"min_steps": 10,
"max_steps": 100,
}

def __init__(
self,
subnet: str | type = "mlp",
@@ -41,41 +66,28 @@ def __init__(

self.use_optimal_transport = use_optimal_transport

if integrate_kwargs is None:
integrate_kwargs = {
"method": "rk45",
"steps": "adaptive",
"tolerance": 1e-3,
"min_steps": 10,
"max_steps": 100,
}

self.integrate_kwargs = integrate_kwargs

if optimal_transport_kwargs is None:
optimal_transport_kwargs = {
"method": "sinkhorn",
"cost": "euclidean",
"regularization": 0.1,
"max_steps": 100,
"tolerance": 1e-4,
}
self.integrate_kwargs = integrate_kwargs or FlowMatching.INTEGRATE_DEFAULT_CONFIG.copy()
self.optimal_transport_kwargs = optimal_transport_kwargs or FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG.copy()

self.loss_fn = keras.losses.get(loss_fn)

self.optimal_transport_kwargs = optimal_transport_kwargs

self.seed_generator = keras.random.SeedGenerator()

self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
if subnet == "mlp":
subnet_kwargs = FlowMatching.MLP_DEFAULT_CONFIG.copy()
subnet_kwargs.update(kwargs.get("subnet_kwargs", {}))
else:
subnet_kwargs = kwargs.get("subnet_kwargs", {})

self.subnet = find_network(subnet, **subnet_kwargs)
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")

# serialization: store all parameters necessary to call __init__
self.config = {
"base_distribution": base_distribution,
"use_optimal_transport": use_optimal_transport,
"optimal_transport_kwargs": optimal_transport_kwargs,
"integrate_kwargs": integrate_kwargs,
"use_optimal_transport": self.use_optimal_transport,
"optimal_transport_kwargs": self.optimal_transport_kwargs,
"integrate_kwargs": self.integrate_kwargs,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)
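Unlike the key-by-key merge used for subnet_kwargs, the `or ... .copy()` idiom replaces the whole dict: any non-empty integrate_kwargs discards the remaining defaults instead of overlaying them. A standalone sketch of the distinction (simplified, not bayesflow code):

INTEGRATE_DEFAULT_CONFIG = {
    "method": "rk45",
    "steps": "adaptive",
    "tolerance": 1e-3,
    "min_steps": 10,
    "max_steps": 100,
}

def resolve_integrate_kwargs(integrate_kwargs=None):
    # None (or an empty dict) falls back to a fresh copy of the defaults;
    # a non-empty dict replaces them wholesale -- no merge.
    return integrate_kwargs or INTEGRATE_DEFAULT_CONFIG.copy()

assert resolve_integrate_kwargs() == INTEGRATE_DEFAULT_CONFIG
assert resolve_integrate_kwargs({"method": "euler"}) == {"method": "euler"}

Storing the resolved dicts (self.integrate_kwargs, self.optimal_transport_kwargs) in self.config also means a deserialized instance records the concrete defaults rather than None.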
36 changes: 34 additions & 2 deletions bayesflow/networks/free_form_flow/free_form_flow.py
@@ -31,6 +31,24 @@ class FreeFormFlow(InferenceNetwork):
In International Conference on Learning Representations.
"""

ENCODER_MLP_DEFAULT_CONFIG = {
"widths": (256, 256, 256, 256),
"activation": "mish",
"kernel_initializer": "he_normal",
"residual": True,
"dropout": 0.05,
"spectral_normalization": False,
}

DECODER_MLP_DEFAULT_CONFIG = {
"widths": (256, 256, 256, 256),
"activation": "mish",
"kernel_initializer": "he_normal",
"residual": True,
"dropout": 0.05,
"spectral_normalization": False,
}

def __init__(
self,
beta: float = 50.0,
@@ -62,9 +80,23 @@ def __init__(
Additional keyword arguments
"""
super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
self.encoder_subnet = find_network(encoder_subnet, **kwargs.get("encoder_subnet_kwargs", {}))

if encoder_subnet == "mlp":
encoder_subnet_kwargs = FreeFormFlow.ENCODER_MLP_DEFAULT_CONFIG.copy()
encoder_subnet_kwargs.update(kwargs.get("encoder_subnet_kwargs", {}))
else:
encoder_subnet_kwargs = kwargs.get("encoder_subnet_kwargs", {})

self.encoder_subnet = find_network(encoder_subnet, **encoder_subnet_kwargs)
self.encoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
self.decoder_subnet = find_network(decoder_subnet, **kwargs.get("decoder_subnet_kwargs", {}))

if decoder_subnet == "mlp":
decoder_subnet_kwargs = FreeFormFlow.DECODER_MLP_DEFAULT_CONFIG.copy()
decoder_subnet_kwargs.update(kwargs.get("decoder_subnet_kwargs", {}))
else:
decoder_subnet_kwargs = kwargs.get("decoder_subnet_kwargs", {})

self.decoder_subnet = find_network(decoder_subnet, **decoder_subnet_kwargs)
self.decoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")

self.hutchinson_sampling = hutchinson_sampling
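Hypothetical usage (assuming FreeFormFlow is exported from bayesflow.networks): the encoder and decoder carry identical four-layer defaults, and each side can be overridden independently.

from bayesflow.networks import FreeFormFlow  # assumed export path

# Shrink only the encoder; the decoder keeps DECODER_MLP_DEFAULT_CONFIG.
flow = FreeFormFlow(
    encoder_subnet="mlp",
    encoder_subnet_kwargs={"widths": (128, 128)},
    decoder_subnet="mlp",
)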
2 changes: 2 additions & 0 deletions bayesflow/networks/inference_network.py
@@ -6,6 +6,8 @@


class InferenceNetwork(keras.Layer):
MLP_DEFAULT_CONFIG = {}

def __init__(self, base_distribution: str = "normal", **kwargs):
super().__init__(**kwargs)
self.base_distribution = find_distribution(base_distribution)
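The empty base-class default gives every subclass a uniform attribute to override, so generic code can read cls.MLP_DEFAULT_CONFIG without hasattr checks. A minimal standalone sketch:

class InferenceNetwork:
    MLP_DEFAULT_CONFIG = {}

class ConsistencyModel(InferenceNetwork):
    MLP_DEFAULT_CONFIG = {"widths": (256,) * 5, "activation": "mish"}

for cls in (InferenceNetwork, ConsistencyModel):
    print(cls.__name__, cls.MLP_DEFAULT_CONFIG)  # {} for the base class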