Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Allow serialization of custom networks #284

Merged
merged 4 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion bayesflow/networks/consistency_models/consistency_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type


from ..inference_network import InferenceNetwork
Expand Down Expand Up @@ -88,6 +88,27 @@ def __init__(

self.seed_generator = keras.random.SeedGenerator()

# serialization: store all parameters necessary to call __init__
self.config = {
"total_steps": total_steps,
"max_time": max_time,
"sigma2": sigma2,
"eps": eps,
"s0": s0,
"s1": s1,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

def _schedule_discretization(self, step) -> float:
"""Schedule function for adjusting the discretization level `N` during
the course of training.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@
import numpy as np

from bayesflow.types import Tensor
from bayesflow.utils import jvp, concatenate, find_network, keras_kwargs, expand_right_as, expand_right_to
from bayesflow.utils import (
jvp,
concatenate,
find_network,
keras_kwargs,
expand_right_as,
expand_right_to,
serialize_value_or_type,
deserialize_value_or_type,
)


from ..inference_network import InferenceNetwork
Expand Down Expand Up @@ -62,6 +71,22 @@ def __init__(

self.seed_generator = keras.random.SeedGenerator()

# serialization: store all parameters necessary to call __init__
self.config = {
"sigma_data": sigma_data,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs):
t = np.linspace(0.0, np.pi / 2, num_steps)
times = np.exp((t - np.pi / 2) * rho) * np.pi / 2
Expand Down
22 changes: 21 additions & 1 deletion bayesflow/networks/coupling_flow/coupling_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_permutation, keras_kwargs
from bayesflow.utils import find_permutation, keras_kwargs, serialize_value_or_type, deserialize_value_or_type

from .actnorm import ActNorm
from .couplings import DualCoupling
Expand Down Expand Up @@ -58,13 +58,33 @@ def __init__(

self.invertible_layers.append(DualCoupling(subnet, transform, **kwargs.get("coupling_kwargs", {})))

# serialization: store all parameters necessary to call __init__
self.config = {
"depth": depth,
"transform": transform,
"permutation": permutation,
"use_actnorm": use_actnorm,
"base_distribution": base_distribution,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
super().build(xz_shape)

for layer in self.invertible_layers:
layer.build(xz_shape=xz_shape, conditions_shape=conditions_shape)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
Expand Down
18 changes: 17 additions & 1 deletion bayesflow/networks/coupling_flow/couplings/dual_coupling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.utils import keras_kwargs
from bayesflow.utils import keras_kwargs, serialize_value_or_type, deserialize_value_or_type
from bayesflow.types import Tensor
from .single_coupling import SingleCoupling
from ..invertible_layer import InvertibleLayer
Expand All @@ -15,6 +15,22 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar
self.coupling2 = SingleCoupling(subnet, transform, **kwargs)
self.pivot = None

# serialization: store all parameters necessary to call __init__
self.config = {
"transform": transform,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
self.pivot = xz_shape[-1] // 2
Expand Down
18 changes: 17 additions & 1 deletion bayesflow/networks/coupling_flow/couplings/single_coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type
from ..invertible_layer import InvertibleLayer
from ..transforms import find_transform

Expand All @@ -26,6 +26,22 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar
output_projector_kwargs.setdefault("kernel_initializer", "zeros")
self.output_projector = keras.layers.Dense(units=None, **output_projector_kwargs)

# serialization: store all parameters necessary to call __init__
self.config = {
"transform": transform,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

# noinspection PyMethodOverriding
def build(self, x1_shape, x2_shape, conditions_shape=None):
self.output_projector.units = self.transform.params_per_dim * x2_shape[-1]
Expand Down
27 changes: 26 additions & 1 deletion bayesflow/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor
from bayesflow.utils import expand_right_as, keras_kwargs, optimal_transport
from bayesflow.utils import (
expand_right_as,
keras_kwargs,
optimal_transport,
serialize_value_or_type,
deserialize_value_or_type,
)
from ..inference_network import InferenceNetwork
from .integrators import EulerIntegrator
from .integrators import RK2Integrator
Expand Down Expand Up @@ -52,10 +58,29 @@ def __init__(
case _:
raise NotImplementedError(f"No support for {integrator} integration")

# serialization: store all parameters necessary to call __init__
self.config = {
"base_distribution": base_distribution,
"integrator": integrator,
"use_optimal_transport": use_optimal_transport,
"optimal_transport_kwargs": optimal_transport_kwargs,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
super().build(xz_shape)
self.integrator.build(xz_shape, conditions_shape)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
Expand Down
31 changes: 30 additions & 1 deletion bayesflow/networks/free_form_flow/free_form_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant, jvp, vjp
from bayesflow.utils import (
find_network,
keras_kwargs,
concatenate,
log_jacobian_determinant,
jvp,
vjp,
serialize_value_or_type,
deserialize_value_or_type,
)

from ..inference_network import InferenceNetwork

Expand Down Expand Up @@ -63,6 +72,26 @@ def __init__(

self.seed_generator = keras.random.SeedGenerator()

# serialization: store all parameters necessary to call __init__
self.config = {
"beta": beta,
"base_distribution": base_distribution,
"hutchinson_sampling": hutchinson_sampling,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "encoder_subnet", encoder_subnet)
self.config = serialize_value_or_type(self.config, "decoder_subnet", decoder_subnet)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_value_or_type(config, "encoder_subnet")
config = deserialize_value_or_type(config, "decoder_subnet")
return cls(**config)

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
super().build(xz_shape)
Expand Down
1 change: 1 addition & 0 deletions bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
format_bytes,
parse_bytes,
)
from .serialization import serialize_value_or_type, deserialize_value_or_type
from .jacobian_trace import jacobian_trace
from .jacobian import compute_jacobian, log_jacobian_determinant
from .jvp import jvp
Expand Down
78 changes: 78 additions & 0 deletions bayesflow/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import keras


PREFIX = "_bayesflow_"


def serialize_value_or_type(config, name, obj):
"""Serialize an object that can be either a value or a type
and add it to a copy of the supplied dictionary.

Parameters
----------
config : dict
Dictionary to add the serialized object to. This function does not
modify the dictionary in place, but returns a modified copy.
name : str
Name of the obj that should be stored. Required for later deserialization.
obj : object or type
The object to serialize. If `obj` is of type `type`, we use
`keras.saving.get_registered_name` to obtain the registered type name.
If it is not a type, we try to serialize it as a Keras object.

Returns
-------
updated_config : dict
Updated dictionary with a new key `"_bayesflow_<name>_type"` or
`"_bayesflow_<name>_val"`. The prefix is used to avoid name collisions,
the suffix indicates how the stored value has to be deserialized.

Notes
-----
We allow strings or `type` parameters at several places to instantiate objects
of a given type (e.g., `subnet` in `CouplingFlow`). As `type` objects cannot
be serialized, we have to distinguish the two cases for serialization and
deserialization. This function is a helper function to standardize and
simplify this.
"""
updated_config = config.copy()
if isinstance(obj, type):
updated_config[f"{PREFIX}{name}_type"] = keras.saving.get_registered_name(obj)
else:
updated_config[f"{PREFIX}{name}_val"] = keras.saving.serialize_keras_object(obj)
return updated_config


def deserialize_value_or_type(config, name):
"""Deserialize an object that can be either a value or a type and add
it to the supplied dictionary.

Parameters
----------
config : dict
Dictionary containing the object to deserialize. If a type was
serialized, it should contain the key `"_bayesflow_<name>_type"`.
If an object was serialized, it should contain the key
`"_bayesflow_<name>_val"`. In a copy of this dictionary,
the item will be replaced with the key `name`.
name : str
Name of the object to deserialize.

Returns
-------
updated_config : dict
Updated dictionary with a new key `name`, with a value that is either
a type or an object.

See Also
--------
`serialize_value_or_type`
"""
updated_config = config.copy()
if f"{PREFIX}{name}_type" in config:
updated_config[name] = keras.saving.get_registered_object(config[f"{PREFIX}{name}_type"])
del updated_config[f"{PREFIX}{name}_type"]
elif f"{PREFIX}{name}_val" in config:
updated_config[name] = keras.saving.deserialize_keras_object(config[f"{PREFIX}{name}_val"])
del updated_config[f"{PREFIX}{name}_val"]
return updated_config
Loading
Loading