Skip to content

Commit

Permalink
rename (de)serialize_val_or_type to (de)serialize_value_or_type
Browse files Browse the repository at this point in the history
  • Loading branch information
vpratz committed Dec 20, 2024
1 parent 603f3e9 commit 8f876f0
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 30 deletions.
6 changes: 3 additions & 3 deletions bayesflow/networks/consistency_models/consistency_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs, serialize_val_or_type, deserialize_val_or_type
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type


from ..inference_network import InferenceNetwork
Expand Down Expand Up @@ -98,15 +98,15 @@ def __init__(
"s1": s1,
**kwargs,
}
self.config = serialize_val_or_type(self.config, "subnet", subnet)
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_val_or_type(config, "subnet")
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

def _schedule_discretization(self, step) -> float:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
keras_kwargs,
expand_right_as,
expand_right_to,
serialize_val_or_type,
deserialize_val_or_type,
serialize_value_or_type,
deserialize_value_or_type,
)


Expand Down Expand Up @@ -76,15 +76,15 @@ def __init__(
"sigma_data": sigma_data,
**kwargs,
}
self.config = serialize_val_or_type(self.config, "subnet", subnet)
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_val_or_type(config, "subnet")
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions bayesflow/networks/coupling_flow/coupling_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_permutation, keras_kwargs, serialize_val_or_type, deserialize_val_or_type
from bayesflow.utils import find_permutation, keras_kwargs, serialize_value_or_type, deserialize_value_or_type

from .actnorm import ActNorm
from .couplings import DualCoupling
Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(
"base_distribution": base_distribution,
**kwargs,
}
self.config = serialize_val_or_type(self.config, "subnet", subnet)
self.config = serialize_value_or_type(self.config, "subnet", subnet)

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
Expand All @@ -82,7 +82,7 @@ def get_config(self):

@classmethod
def from_config(cls, config):
config = deserialize_val_or_type(config, "subnet")
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

def _forward(
Expand Down
6 changes: 3 additions & 3 deletions bayesflow/networks/coupling_flow/couplings/dual_coupling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.utils import keras_kwargs, serialize_val_or_type, deserialize_val_or_type
from bayesflow.utils import keras_kwargs, serialize_value_or_type, deserialize_value_or_type
from bayesflow.types import Tensor
from .single_coupling import SingleCoupling
from ..invertible_layer import InvertibleLayer
Expand All @@ -20,15 +20,15 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar
"transform": transform,
**kwargs,
}
self.config = serialize_val_or_type(self.config, "subnet", subnet)
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_val_or_type(config, "subnet")
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

# noinspection PyMethodOverriding
Expand Down
6 changes: 3 additions & 3 deletions bayesflow/networks/coupling_flow/couplings/single_coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs, serialize_val_or_type, deserialize_val_or_type
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type
from ..invertible_layer import InvertibleLayer
from ..transforms import find_transform

Expand Down Expand Up @@ -31,15 +31,15 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar
"transform": transform,
**kwargs,
}
self.config = serialize_val_or_type(self.config, "subnet", subnet)
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_val_or_type(config, "subnet")
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

# noinspection PyMethodOverriding
Expand Down
8 changes: 4 additions & 4 deletions bayesflow/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
expand_right_as,
keras_kwargs,
optimal_transport,
serialize_val_or_type,
deserialize_val_or_type,
serialize_value_or_type,
deserialize_value_or_type,
)
from ..inference_network import InferenceNetwork
from .integrators import EulerIntegrator
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
"optimal_transport_kwargs": optimal_transport_kwargs,
**kwargs,
}
self.config = serialize_val_or_type(self.config, "subnet", subnet)
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
super().build(xz_shape)
Expand All @@ -78,7 +78,7 @@ def get_config(self):

@classmethod
def from_config(cls, config):
config = deserialize_val_or_type(config, "subnet")
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

def _forward(
Expand Down
12 changes: 6 additions & 6 deletions bayesflow/networks/free_form_flow/free_form_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
log_jacobian_determinant,
jvp,
vjp,
serialize_val_or_type,
deserialize_val_or_type,
serialize_value_or_type,
deserialize_value_or_type,
)

from ..inference_network import InferenceNetwork
Expand Down Expand Up @@ -79,17 +79,17 @@ def __init__(
"hutchinson_sampling": hutchinson_sampling,
**kwargs,
}
self.config = serialize_val_or_type(self.config, "encoder_subnet", encoder_subnet)
self.config = serialize_val_or_type(self.config, "decoder_subnet", decoder_subnet)
self.config = serialize_value_or_type(self.config, "encoder_subnet", encoder_subnet)
self.config = serialize_value_or_type(self.config, "decoder_subnet", decoder_subnet)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_val_or_type(config, "encoder_subnet")
config = deserialize_val_or_type(config, "decoder_subnet")
config = deserialize_value_or_type(config, "encoder_subnet")
config = deserialize_value_or_type(config, "decoder_subnet")
return cls(**config)

# noinspection PyMethodOverriding
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
format_bytes,
parse_bytes,
)
from .serialization import serialize_val_or_type, deserialize_val_or_type
from .serialization import serialize_value_or_type, deserialize_value_or_type
from .jacobian_trace import jacobian_trace
from .jacobian import compute_jacobian, log_jacobian_determinant
from .jvp import jvp
Expand Down
6 changes: 3 additions & 3 deletions bayesflow/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
PREFIX = "_bayesflow_"


def serialize_val_or_type(config, name, obj):
def serialize_value_or_type(config, name, obj):
"""Serialize an object that can be either a value or a type
and add it to a copy of the supplied dictionary.
Expand Down Expand Up @@ -43,7 +43,7 @@ def serialize_val_or_type(config, name, obj):
return updated_config


def deserialize_val_or_type(config, name):
def deserialize_value_or_type(config, name):
"""Deserialize an object that can be either a value or a type and add
it to the supplied dictionary.
Expand All @@ -66,7 +66,7 @@ def deserialize_val_or_type(config, name):
See Also
--------
`serialize_val_or_type`
`serialize_value_or_type`
"""
updated_config = config.copy()
if f"{PREFIX}{name}_type" in config:
Expand Down

0 comments on commit 8f876f0

Please sign in to comment.