From 8f876f0c7b9f64a575f00cabacec9692fd4d5410 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 20 Dec 2024 13:40:59 +0000 Subject: [PATCH] rename (de)serialize_val_or_type to (de)serialize_value_or_type --- .../networks/consistency_models/consistency_model.py | 6 +++--- .../continuous_consistency_model.py | 8 ++++---- bayesflow/networks/coupling_flow/coupling_flow.py | 6 +++--- .../coupling_flow/couplings/dual_coupling.py | 6 +++--- .../coupling_flow/couplings/single_coupling.py | 6 +++--- bayesflow/networks/flow_matching/flow_matching.py | 8 ++++---- bayesflow/networks/free_form_flow/free_form_flow.py | 12 ++++++------ bayesflow/utils/__init__.py | 2 +- bayesflow/utils/serialization.py | 6 +++--- 9 files changed, 30 insertions(+), 30 deletions(-) diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index 513d6a933..03a27c6e6 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -7,7 +7,7 @@ import numpy as np from bayesflow.types import Tensor -from bayesflow.utils import find_network, keras_kwargs, serialize_val_or_type, deserialize_val_or_type +from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type from ..inference_network import InferenceNetwork @@ -98,7 +98,7 @@ def __init__( "s1": s1, **kwargs, } - self.config = serialize_val_or_type(self.config, "subnet", subnet) + self.config = serialize_value_or_type(self.config, "subnet", subnet) def get_config(self): base_config = super().get_config() @@ -106,7 +106,7 @@ def get_config(self): @classmethod def from_config(cls, config): - config = deserialize_val_or_type(config, "subnet") + config = deserialize_value_or_type(config, "subnet") return cls(**config) def _schedule_discretization(self, step) -> float: diff --git a/bayesflow/networks/consistency_models/continuous_consistency_model.py b/bayesflow/networks/consistency_models/continuous_consistency_model.py index 251d99085..c459ab535 100644 --- a/bayesflow/networks/consistency_models/continuous_consistency_model.py +++ b/bayesflow/networks/consistency_models/continuous_consistency_model.py @@ -14,8 +14,8 @@ keras_kwargs, expand_right_as, expand_right_to, - serialize_val_or_type, - deserialize_val_or_type, + serialize_value_or_type, + deserialize_value_or_type, ) @@ -76,7 +76,7 @@ def __init__( "sigma_data": sigma_data, **kwargs, } - self.config = serialize_val_or_type(self.config, "subnet", subnet) + self.config = serialize_value_or_type(self.config, "subnet", subnet) def get_config(self): base_config = super().get_config() @@ -84,7 +84,7 @@ def get_config(self): @classmethod def from_config(cls, config): - config = deserialize_val_or_type(config, "subnet") + config = deserialize_value_or_type(config, "subnet") return cls(**config) def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs): diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index 9a935b72d..a357d52d8 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -2,7 +2,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor -from bayesflow.utils import find_permutation, keras_kwargs, serialize_val_or_type, deserialize_val_or_type +from bayesflow.utils import find_permutation, keras_kwargs, serialize_value_or_type, deserialize_value_or_type from .actnorm import ActNorm from .couplings import DualCoupling @@ -67,7 +67,7 @@ def __init__( "base_distribution": base_distribution, **kwargs, } - self.config = serialize_val_or_type(self.config, "subnet", subnet) + self.config = serialize_value_or_type(self.config, "subnet", subnet) # noinspection PyMethodOverriding def build(self, xz_shape, conditions_shape=None): @@ -82,7 +82,7 @@ def get_config(self): @classmethod def from_config(cls, config): - config = deserialize_val_or_type(config, "subnet") + config = deserialize_value_or_type(config, "subnet") return cls(**config) def _forward( diff --git a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py index b9017b069..a862ad952 100644 --- a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py +++ b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py @@ -1,7 +1,7 @@ import keras from keras.saving import register_keras_serializable as serializable -from bayesflow.utils import keras_kwargs, serialize_val_or_type, deserialize_val_or_type +from bayesflow.utils import keras_kwargs, serialize_value_or_type, deserialize_value_or_type from bayesflow.types import Tensor from .single_coupling import SingleCoupling from ..invertible_layer import InvertibleLayer @@ -20,7 +20,7 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar "transform": transform, **kwargs, } - self.config = serialize_val_or_type(self.config, "subnet", subnet) + self.config = serialize_value_or_type(self.config, "subnet", subnet) def get_config(self): base_config = super().get_config() @@ -28,7 +28,7 @@ def get_config(self): @classmethod def from_config(cls, config): - config = deserialize_val_or_type(config, "subnet") + config = deserialize_value_or_type(config, "subnet") return cls(**config) # noinspection PyMethodOverriding diff --git a/bayesflow/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/networks/coupling_flow/couplings/single_coupling.py index eee4ff9dd..f4fba7cb1 100644 --- a/bayesflow/networks/coupling_flow/couplings/single_coupling.py +++ b/bayesflow/networks/coupling_flow/couplings/single_coupling.py @@ -3,7 +3,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor -from bayesflow.utils import find_network, keras_kwargs, serialize_val_or_type, deserialize_val_or_type +from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type from ..invertible_layer import InvertibleLayer from ..transforms import find_transform @@ -31,7 +31,7 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar "transform": transform, **kwargs, } - self.config = serialize_val_or_type(self.config, "subnet", subnet) + self.config = serialize_value_or_type(self.config, "subnet", subnet) def get_config(self): base_config = super().get_config() @@ -39,7 +39,7 @@ def get_config(self): @classmethod def from_config(cls, config): - config = deserialize_val_or_type(config, "subnet") + config = deserialize_value_or_type(config, "subnet") return cls(**config) # noinspection PyMethodOverriding diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index beeaeff58..118d3546f 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -7,8 +7,8 @@ expand_right_as, keras_kwargs, optimal_transport, - serialize_val_or_type, - deserialize_val_or_type, + serialize_value_or_type, + deserialize_value_or_type, ) from ..inference_network import InferenceNetwork from .integrators import EulerIntegrator @@ -66,7 +66,7 @@ def __init__( "optimal_transport_kwargs": optimal_transport_kwargs, **kwargs, } - self.config = serialize_val_or_type(self.config, "subnet", subnet) + self.config = serialize_value_or_type(self.config, "subnet", subnet) def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: super().build(xz_shape) @@ -78,7 +78,7 @@ def get_config(self): @classmethod def from_config(cls, config): - config = deserialize_val_or_type(config, "subnet") + config = deserialize_value_or_type(config, "subnet") return cls(**config) def _forward( diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/networks/free_form_flow/free_form_flow.py index db5272667..23c375c1f 100644 --- a/bayesflow/networks/free_form_flow/free_form_flow.py +++ b/bayesflow/networks/free_form_flow/free_form_flow.py @@ -10,8 +10,8 @@ log_jacobian_determinant, jvp, vjp, - serialize_val_or_type, - deserialize_val_or_type, + serialize_value_or_type, + deserialize_value_or_type, ) from ..inference_network import InferenceNetwork @@ -79,8 +79,8 @@ def __init__( "hutchinson_sampling": hutchinson_sampling, **kwargs, } - self.config = serialize_val_or_type(self.config, "encoder_subnet", encoder_subnet) - self.config = serialize_val_or_type(self.config, "decoder_subnet", decoder_subnet) + self.config = serialize_value_or_type(self.config, "encoder_subnet", encoder_subnet) + self.config = serialize_value_or_type(self.config, "decoder_subnet", decoder_subnet) def get_config(self): base_config = super().get_config() @@ -88,8 +88,8 @@ def get_config(self): @classmethod def from_config(cls, config): - config = deserialize_val_or_type(config, "encoder_subnet") - config = deserialize_val_or_type(config, "decoder_subnet") + config = deserialize_value_or_type(config, "encoder_subnet") + config = deserialize_value_or_type(config, "decoder_subnet") return cls(**config) # noinspection PyMethodOverriding diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index 562d97712..ec8f7fffb 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -18,7 +18,7 @@ format_bytes, parse_bytes, ) -from .serialization import serialize_val_or_type, deserialize_val_or_type +from .serialization import serialize_value_or_type, deserialize_value_or_type from .jacobian_trace import jacobian_trace from .jacobian import compute_jacobian, log_jacobian_determinant from .jvp import jvp diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py index 32f2406be..3d621585c 100644 --- a/bayesflow/utils/serialization.py +++ b/bayesflow/utils/serialization.py @@ -4,7 +4,7 @@ PREFIX = "_bayesflow_" -def serialize_val_or_type(config, name, obj): +def serialize_value_or_type(config, name, obj): """Serialize an object that can be either a value or a type and add it to a copy of the supplied dictionary. @@ -43,7 +43,7 @@ def serialize_val_or_type(config, name, obj): return updated_config -def deserialize_val_or_type(config, name): +def deserialize_value_or_type(config, name): """Deserialize an object that can be either a value or a type and add it to the supplied dictionary. @@ -66,7 +66,7 @@ def deserialize_val_or_type(config, name): See Also -------- - `serialize_val_or_type` + `serialize_value_or_type` """ updated_config = config.copy() if f"{PREFIX}{name}_type" in config: