From f9355ad603dad928f27eb695a9e3b66ba2f0d80b Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 13 Dec 2024 11:11:11 +0000 Subject: [PATCH] feat: allow serialization of custom networks This commit adds utility functions and extends existing networks to enable serialization of complete networks when custom network types are passed as arguments (e.g., for sub-networks in coupling flows). The main complications were: * Objects of type `type` (uninstantiated classes) cannot be serialized using `keras.saving.serialize_keras_object`, as they have no `get_config` method. * We want to support both strings and types as parameters, so we need to distinguish between the two during manual serialization/deserialization. * Auto-discovery of __init__ parameters is only active when `get_config` is not overridden, so the configuration has to be stored manually for serialization. To store types, we use `keras.saving.get_registered_name`; the type can be reconstructed at deserialization using `keras.saving.get_registered_object`. Handling of the different cases is moved to the utility functions `(de)serialize_val_or_type`, which use a naming scheme to determine which deserialization method to use. The same setup can be extended to other custom types, e.g., distributions. --- .../consistency_models/consistency_model.py | 23 +++++- .../continuous_consistency_model.py | 27 ++++++- .../networks/coupling_flow/coupling_flow.py | 22 +++++- .../coupling_flow/couplings/dual_coupling.py | 18 ++++- .../couplings/single_coupling.py | 18 ++++- .../networks/flow_matching/flow_matching.py | 27 ++++++- .../networks/free_form_flow/free_form_flow.py | 31 +++++++- bayesflow/utils/__init__.py | 1 + bayesflow/utils/serialization.py | 78 +++++++++++++++++++ tests/test_networks/conftest.py | 48 +++++++++++- .../test_networks/test_inference_networks.py | 14 ++-- 11 files changed, 290 insertions(+), 17 deletions(-) create mode 100644 bayesflow/utils/serialization.py diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index 3ba275162..513d6a933 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -7,7 +7,7 @@ import numpy as np from bayesflow.types import Tensor -from bayesflow.utils import find_network, keras_kwargs +from bayesflow.utils import find_network, keras_kwargs, serialize_val_or_type, deserialize_val_or_type from ..inference_network import InferenceNetwork @@ -88,6 +88,27 @@ def __init__( self.seed_generator = keras.random.SeedGenerator() + # serialization: store all parameters necessary to call __init__ + self.config = { + "total_steps": total_steps, + "max_time": max_time, + "sigma2": sigma2, + "eps": eps, + "s0": s0, + "s1": s1, + **kwargs, + } + self.config = serialize_val_or_type(self.config, "subnet", subnet) + + def get_config(self): + base_config = super().get_config() + return base_config | self.config + + @classmethod + def from_config(cls, config): + config = deserialize_val_or_type(config, "subnet") + return cls(**config) + def _schedule_discretization(self, step) -> float: """Schedule function for adjusting the discretization level `N` during the course of training.
diff --git a/bayesflow/networks/consistency_models/continuous_consistency_model.py b/bayesflow/networks/consistency_models/continuous_consistency_model.py index 2dc319782..251d99085 100644 --- a/bayesflow/networks/consistency_models/continuous_consistency_model.py +++ b/bayesflow/networks/consistency_models/continuous_consistency_model.py @@ -7,7 +7,16 @@ import numpy as np from bayesflow.types import Tensor -from bayesflow.utils import jvp, concatenate, find_network, keras_kwargs, expand_right_as, expand_right_to +from bayesflow.utils import ( + jvp, + concatenate, + find_network, + keras_kwargs, + expand_right_as, + expand_right_to, + serialize_val_or_type, + deserialize_val_or_type, +) from ..inference_network import InferenceNetwork @@ -62,6 +71,22 @@ def __init__( self.seed_generator = keras.random.SeedGenerator() + # serialization: store all parameters necessary to call __init__ + self.config = { + "sigma_data": sigma_data, + **kwargs, + } + self.config = serialize_val_or_type(self.config, "subnet", subnet) + + def get_config(self): + base_config = super().get_config() + return base_config | self.config + + @classmethod + def from_config(cls, config): + config = deserialize_val_or_type(config, "subnet") + return cls(**config) + def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs): t = np.linspace(0.0, np.pi / 2, num_steps) times = np.exp((t - np.pi / 2) * rho) * np.pi / 2 diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index c5be3bc69..9a935b72d 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -2,7 +2,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor -from bayesflow.utils import find_permutation, keras_kwargs +from bayesflow.utils import find_permutation, keras_kwargs, serialize_val_or_type, deserialize_val_or_type from .actnorm import ActNorm from .couplings import DualCoupling @@ -58,6 +58,17 @@ def __init__( self.invertible_layers.append(DualCoupling(subnet, transform, **kwargs.get("coupling_kwargs", {}))) + # serialization: store all parameters necessary to call __init__ + self.config = { + "depth": depth, + "transform": transform, + "permutation": permutation, + "use_actnorm": use_actnorm, + "base_distribution": base_distribution, + **kwargs, + } + self.config = serialize_val_or_type(self.config, "subnet", subnet) + # noinspection PyMethodOverriding def build(self, xz_shape, conditions_shape=None): super().build(xz_shape) @@ -65,6 +76,15 @@ def build(self, xz_shape, conditions_shape=None): for layer in self.invertible_layers: layer.build(xz_shape=xz_shape, conditions_shape=conditions_shape) + def get_config(self): + base_config = super().get_config() + return base_config | self.config + + @classmethod + def from_config(cls, config): + config = deserialize_val_or_type(config, "subnet") + return cls(**config) + def _forward( self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: diff --git a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py index d66d517bc..b9017b069 100644 --- a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py +++ b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py @@ -1,7 +1,7 @@ import keras from keras.saving import register_keras_serializable as serializable -from 
bayesflow.utils import keras_kwargs +from bayesflow.utils import keras_kwargs, serialize_val_or_type, deserialize_val_or_type from bayesflow.types import Tensor from .single_coupling import SingleCoupling from ..invertible_layer import InvertibleLayer @@ -15,6 +15,22 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar self.coupling2 = SingleCoupling(subnet, transform, **kwargs) self.pivot = None + # serialization: store all parameters necessary to call __init__ + self.config = { + "transform": transform, + **kwargs, + } + self.config = serialize_val_or_type(self.config, "subnet", subnet) + + def get_config(self): + base_config = super().get_config() + return base_config | self.config + + @classmethod + def from_config(cls, config): + config = deserialize_val_or_type(config, "subnet") + return cls(**config) + # noinspection PyMethodOverriding def build(self, xz_shape, conditions_shape=None): self.pivot = xz_shape[-1] // 2 diff --git a/bayesflow/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/networks/coupling_flow/couplings/single_coupling.py index 694bbed2a..eee4ff9dd 100644 --- a/bayesflow/networks/coupling_flow/couplings/single_coupling.py +++ b/bayesflow/networks/coupling_flow/couplings/single_coupling.py @@ -3,7 +3,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor -from bayesflow.utils import find_network, keras_kwargs +from bayesflow.utils import find_network, keras_kwargs, serialize_val_or_type, deserialize_val_or_type from ..invertible_layer import InvertibleLayer from ..transforms import find_transform @@ -26,6 +26,22 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar output_projector_kwargs.setdefault("kernel_initializer", "zeros") self.output_projector = keras.layers.Dense(units=None, **output_projector_kwargs) + # serialization: store all parameters necessary to call __init__ + self.config = { + "transform": transform, + **kwargs, + } + self.config = serialize_val_or_type(self.config, "subnet", subnet) + + def get_config(self): + base_config = super().get_config() + return base_config | self.config + + @classmethod + def from_config(cls, config): + config = deserialize_val_or_type(config, "subnet") + return cls(**config) + # noinspection PyMethodOverriding def build(self, x1_shape, x2_shape, conditions_shape=None): self.output_projector.units = self.transform.params_per_dim * x2_shape[-1] diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index bab2c7aeb..beeaeff58 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -3,7 +3,13 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor -from bayesflow.utils import expand_right_as, keras_kwargs, optimal_transport +from bayesflow.utils import ( + expand_right_as, + keras_kwargs, + optimal_transport, + serialize_val_or_type, + deserialize_val_or_type, +) from ..inference_network import InferenceNetwork from .integrators import EulerIntegrator from .integrators import RK2Integrator @@ -52,10 +58,29 @@ def __init__( case _: raise NotImplementedError(f"No support for {integrator} integration") + # serialization: store all parameters necessary to call __init__ + self.config = { + "base_distribution": base_distribution, + "integrator": integrator, + "use_optimal_transport": use_optimal_transport, + 
"optimal_transport_kwargs": optimal_transport_kwargs, + **kwargs, + } + self.config = serialize_val_or_type(self.config, "subnet", subnet) + def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: super().build(xz_shape) self.integrator.build(xz_shape, conditions_shape) + def get_config(self): + base_config = super().get_config() + return base_config | self.config + + @classmethod + def from_config(cls, config): + config = deserialize_val_or_type(config, "subnet") + return cls(**config) + def _forward( self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/networks/free_form_flow/free_form_flow.py index c893d7df8..db5272667 100644 --- a/bayesflow/networks/free_form_flow/free_form_flow.py +++ b/bayesflow/networks/free_form_flow/free_form_flow.py @@ -3,7 +3,16 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor -from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant, jvp, vjp +from bayesflow.utils import ( + find_network, + keras_kwargs, + concatenate, + log_jacobian_determinant, + jvp, + vjp, + serialize_val_or_type, + deserialize_val_or_type, +) from ..inference_network import InferenceNetwork @@ -63,6 +72,26 @@ def __init__( self.seed_generator = keras.random.SeedGenerator() + # serialization: store all parameters necessary to call __init__ + self.config = { + "beta": beta, + "base_distribution": base_distribution, + "hutchinson_sampling": hutchinson_sampling, + **kwargs, + } + self.config = serialize_val_or_type(self.config, "encoder_subnet", encoder_subnet) + self.config = serialize_val_or_type(self.config, "decoder_subnet", decoder_subnet) + + def get_config(self): + base_config = super().get_config() + return base_config | self.config + + @classmethod + def from_config(cls, config): + config = deserialize_val_or_type(config, "encoder_subnet") + config = deserialize_val_or_type(config, "decoder_subnet") + return cls(**config) + # noinspection PyMethodOverriding def build(self, xz_shape, conditions_shape=None): super().build(xz_shape) diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index 322ff2158..976d9b774 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -18,6 +18,7 @@ format_bytes, parse_bytes, ) +from .serialization import serialize_val_or_type, deserialize_val_or_type from .jacobian_trace import jacobian_trace from .jacobian import compute_jacobian, log_jacobian_determinant from .jvp import jvp diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py new file mode 100644 index 000000000..32f2406be --- /dev/null +++ b/bayesflow/utils/serialization.py @@ -0,0 +1,78 @@ +import keras + + +PREFIX = "_bayesflow_" + + +def serialize_val_or_type(config, name, obj): + """Serialize an object that can be either a value or a type + and add it to a copy of the supplied dictionary. + + Parameters + ---------- + config : dict + Dictionary to add the serialized object to. This function does not + modify the dictionary in place, but returns a modified copy. + name : str + Name of the obj that should be stored. Required for later deserialization. + obj : object or type + The object to serialize. If `obj` is of type `type`, we use + `keras.saving.get_registered_name` to obtain the registered type name. 
+        If it is not a type, we try to serialize it as a Keras object. + + Returns + ------- + updated_config : dict + Updated dictionary with a new key `"_bayesflow_<name>_type"` or + `"_bayesflow_<name>_val"`. The prefix avoids name collisions; + the suffix indicates how the stored value has to be deserialized. + + Notes + ----- + We allow strings or `type` parameters at several places to instantiate objects + of a given type (e.g., `subnet` in `CouplingFlow`). As `type` objects cannot + be serialized, we have to distinguish the two cases for serialization and + deserialization. This helper function standardizes and + simplifies that handling. + """ + updated_config = config.copy() + if isinstance(obj, type): + updated_config[f"{PREFIX}{name}_type"] = keras.saving.get_registered_name(obj) + else: + updated_config[f"{PREFIX}{name}_val"] = keras.saving.serialize_keras_object(obj) + return updated_config + + +def deserialize_val_or_type(config, name): + """Deserialize an object that can be either a value or a type and add + it to a copy of the supplied dictionary. + + Parameters + ---------- + config : dict + Dictionary containing the object to deserialize. If a type was + serialized, it should contain the key `"_bayesflow_<name>_type"`. + If an object was serialized, it should contain the key + `"_bayesflow_<name>_val"`. In a copy of this dictionary, + the item will be replaced with the key `name`. + name : str + Name of the object to deserialize. + + Returns + ------- + updated_config : dict + Updated dictionary with a new key `name`, with a value that is either + a type or an object. + + See Also + -------- + `serialize_val_or_type` + """ + updated_config = config.copy() + if f"{PREFIX}{name}_type" in config: + updated_config[name] = keras.saving.get_registered_object(config[f"{PREFIX}{name}_type"]) + del updated_config[f"{PREFIX}{name}_type"] + elif f"{PREFIX}{name}_val" in config: + updated_config[name] = keras.saving.deserialize_keras_object(config[f"{PREFIX}{name}_val"]) + del updated_config[f"{PREFIX}{name}_val"] + return updated_config diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index 62796f11b..d8eac634c 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -8,6 +8,36 @@ def deep_set(): return DeepSet() +# For the serialization tests, we want to test passing str and type. +# For all other tests, this is not necessary and would double test time. +# Therefore, below we specify two variants of each network, one without and +# one with a subnet parameter. The latter will only be used for the relevant +# tests. If there is a better way to set the params to a single value ("mlp") +# for a given test, maybe this can be simplified, but I did not see one.
+@pytest.fixture(params=["str", "type"], scope="function") +def subnet(request): + if request.param == "str": + return "mlp" + + from bayesflow.networks import MLP + + return MLP + + +@pytest.fixture() +def flow_matching(): + from bayesflow.networks import FlowMatching + + return FlowMatching() + + +@pytest.fixture() +def flow_matching_subnet(subnet): + from bayesflow.networks import FlowMatching + + return FlowMatching(subnet=subnet) + + @pytest.fixture() def coupling_flow(): from bayesflow.networks import CouplingFlow @@ -16,10 +46,10 @@ def coupling_flow(): @pytest.fixture() -def flow_matching(): - from bayesflow.networks import FlowMatching +def coupling_flow_subnet(subnet): + from bayesflow.networks import CouplingFlow - return FlowMatching() + return CouplingFlow(subnet=subnet) @pytest.fixture() @@ -29,11 +59,23 @@ def free_form_flow(): return FreeFormFlow() +@pytest.fixture() +def free_form_flow_subnet(subnet): + from bayesflow.networks import FreeFormFlow + + return FreeFormFlow(encoder_subnet=subnet, decoder_subnet=subnet) + + @pytest.fixture(params=["coupling_flow", "flow_matching", "free_form_flow"], scope="function") def inference_network(request): return request.getfixturevalue(request.param) +@pytest.fixture(params=["coupling_flow_subnet", "flow_matching_subnet", "free_form_flow_subnet"], scope="function") +def inference_network_subnet(request): + return request.getfixturevalue(request.param) + + @pytest.fixture() def lst_net(): from bayesflow.networks import LSTNet diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index cdc4e9e09..33395cdc5 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -125,22 +125,22 @@ def f(x): assert allclose(inverse_log_density, numerical_inverse_log_density, rtol=1e-4, atol=1e-5) -def test_serialize_deserialize(inference_network, random_samples, random_conditions): +def test_serialize_deserialize(inference_network_subnet, subnet, random_samples, random_conditions): # to save, the model must be built - inference_network(random_samples, conditions=random_conditions) + inference_network_subnet(random_samples, conditions=random_conditions) - serialized = serialize(inference_network) + serialized = serialize(inference_network_subnet) deserialized = deserialize(serialized) reserialized = serialize(deserialized) assert serialized == reserialized -def test_save_and_load(tmp_path, inference_network, random_samples, random_conditions): +def test_save_and_load(tmp_path, inference_network_subnet, subnet, random_samples, random_conditions): # to save, the model must be built - inference_network(random_samples, conditions=random_conditions) + inference_network_subnet(random_samples, conditions=random_conditions) - keras.saving.save_model(inference_network, tmp_path / "model.keras") + keras.saving.save_model(inference_network_subnet, tmp_path / "model.keras") loaded = keras.saving.load_model(tmp_path / "model.keras") - assert_layers_equal(inference_network, loaded) + assert_layers_equal(inference_network_subnet, loaded)
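Usage sketch (illustrative, not part of the patch): the helpers in bayesflow/utils/serialization.py store a parameter either as a plain value under "_bayesflow_<name>_val" or as a registered type name under "_bayesflow_<name>_type". The snippet below shows the intended round-trip; `MyCustomSubnet` and the package name "my_package" are hypothetical and used only for illustration.

import keras

from bayesflow.utils import serialize_val_or_type, deserialize_val_or_type


@keras.saving.register_keras_serializable(package="my_package")
class MyCustomSubnet(keras.layers.Layer):
    """Hypothetical custom subnet, registered so Keras can look it up by name."""


# A plain string is serialized as a value under "_bayesflow_subnet_val".
config = serialize_val_or_type({"depth": 6}, "subnet", "mlp")
# e.g. {"depth": 6, "_bayesflow_subnet_val": "mlp"}

# An uninstantiated class is stored by its registered name under "_bayesflow_subnet_type".
config = serialize_val_or_type({"depth": 6}, "subnet", MyCustomSubnet)
# e.g. {"depth": 6, "_bayesflow_subnet_type": "my_package>MyCustomSubnet"}

# Deserialization restores the original keyword so the dict can be passed to __init__.
restored = deserialize_val_or_type(config, "subnet")
assert restored["subnet"] is MyCustomSubnet and restored["depth"] == 6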
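Usage sketch (illustrative, not part of the patch): with the per-network `get_config`/`from_config` additions, a network constructed with a custom subnet type can be saved and reloaded end to end, mirroring the updated tests. The shapes and the file name below are arbitrary.

import keras

from bayesflow.networks import CouplingFlow, MLP

# Pass the class itself rather than the string "mlp".
flow = CouplingFlow(subnet=MLP)

# The network must be built (called once) before it can be saved.
samples = keras.random.normal((8, 4))
conditions = keras.random.normal((8, 2))
flow(samples, conditions=conditions)

keras.saving.save_model(flow, "coupling_flow.keras")
loaded = keras.saving.load_model("coupling_flow.keras")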