
Commit e691644

Adapter: Minor fixes and more thorough testing (#317)
* FilterTransform: deserialize kwargs and transform_constructor
  * The deserialize call for `kwargs` was missing.
  * The constructor can now be (de)serialized using get_registered_name and get_registered_object. Added a meaningful error message for when functions are not registered and not passed as custom objects.
* transforms: ensure serialization cycle consistency
  * Keras deserializes tuples as lists. This can cause problems and has to be undone manually.
* adapter: expand tests to include all transforms
  * Also expand the test to ensure consistency of outputs before and after serialization.
1 parent 6ff721c commit e691644
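The constructor (de)serialization described above builds on Keras' registration utilities. Below is a minimal sketch of that mechanism, not part of this commit; `my_transform` and `plain_function` are hypothetical names. A callable decorated with `register_keras_serializable` can be recovered from its registered name alone, while a plain function only resolves if it is supplied via `custom_objects`:

from keras.saving import (
    get_registered_name,
    get_registered_object,
    register_keras_serializable,
)


@register_keras_serializable(package="my_package")
def my_transform(x):
    return x + 1


def plain_function(x):
    return x - 1


# A registered callable round-trips through its registered name.
name = get_registered_name(my_transform)  # e.g. "my_package>my_transform"
assert get_registered_object(name) is my_transform

# An unregistered callable falls back to its __name__ and is only found
# if passed explicitly as a custom object.
plain_name = get_registered_name(plain_function)  # "plain_function"
assert (
    get_registered_object(plain_name, custom_objects={"plain_function": plain_function})
    is plain_function
)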

5 files changed (+67, −18 lines)


bayesflow/adapters/transforms/broadcast.py

Lines changed: 10 additions & 3 deletions
@@ -79,12 +79,19 @@ def __init__(
 
     @classmethod
     def from_config(cls, config: dict, custom_objects=None) -> "Broadcast":
+        # Deserialize turns tuples to lists, undo it if necessary
+        exclude = deserialize(config["exclude"], custom_objects)
+        exclude = tuple(exclude) if isinstance(exclude, list) else exclude
+        expand = deserialize(config["expand"], custom_objects)
+        expand = tuple(expand) if isinstance(expand, list) else expand
+        squeeze = deserialize(config["squeeze"], custom_objects)
+        squeeze = tuple(squeeze) if isinstance(squeeze, list) else squeeze
         return cls(
             keys=deserialize(config["keys"], custom_objects),
             to=deserialize(config["to"], custom_objects),
-            expand=deserialize(config["expand"], custom_objects),
-            exclude=deserialize(config["exclude"], custom_objects),
-            squeeze=deserialize(config["squeeze"], custom_objects),
+            expand=expand,
+            exclude=exclude,
+            squeeze=squeeze,
         )
 
     def get_config(self) -> dict:
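The comment in `from_config` refers to a general Keras behaviour: its JSON-based config format has no tuple type, so tuples come back as lists after a serialization cycle. A small illustration, independent of this diff:

from keras.saving import deserialize_keras_object, serialize_keras_object

original = (0, 1)
restored = deserialize_keras_object(serialize_keras_object(original))
print(type(restored))  # typically <class 'list'> -- the tuple type is lost in the round trip

# The pattern used in from_config above restores the original type:
restored = tuple(restored) if isinstance(restored, list) else restored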

bayesflow/adapters/transforms/filter_transform.py

Lines changed: 25 additions & 10 deletions
@@ -4,6 +4,8 @@
 import numpy as np
 from keras.saving import (
     deserialize_keras_object as deserialize,
+    get_registered_name,
+    get_registered_object,
     register_keras_serializable as serializable,
     serialize_keras_object as serialize,
 )
@@ -79,21 +81,33 @@ def extra_repr(self) -> str:
 
     @classmethod
     def from_config(cls, config: dict, custom_objects=None) -> "Transform":
-        def transform_constructor(*args, **kwargs):
-            raise RuntimeError(
-                "Instantiating new elementwise transforms on a deserialized FilterTransform is not yet supported (and"
-                "may never be). As a work-around, you can manually register the elementwise transform constructor after"
-                "deserialization:\n"
-                "obj = deserialize(config)\n"
-                "obj.transform_constructor = MyElementwiseTransform"
-            )
-
+        transform_constructor = get_registered_object(config["transform_constructor"])
+        try:
+            kwargs = deserialize(config["kwargs"])
+        except TypeError as e:
+            if transform_constructor.__name__ == "LambdaTransform":
+                raise TypeError(
+                    "LambdaTransform (created by Adapter.apply) could not be deserialized.\n"
+                    "This is probably because the custom transform functions `forward` and "
+                    "`backward` from `Adapter.apply` were not passed as `custom_objects`.\n"
+                    "For example, if your adapter uses\n"
+                    "`Adapter.apply(forward=forward_transform, inverse=inverse_transform)`,\n"
+                    "you have to pass\n"
+                    '`custom_objects={"forward_transform": forward_transform, '
+                    '"inverse_transform": inverse_transform}`\n'
+                    "to the function you use to load the serialized object."
+                ) from e
+            raise TypeError(
+                "The transform could not be deserialized properly. "
+                "The most likely reason is that some classes or functions "
+                "are not known during deserialization. Please pass them as `custom_objects`."
+            ) from e
         instance = cls(
             transform_constructor=transform_constructor,
             predicate=deserialize(config["predicate"], custom_objects),
             include=deserialize(config["include"], custom_objects),
             exclude=deserialize(config["exclude"], custom_objects),
-            **config["kwargs"],
+            **kwargs,
         )
 
         instance.transform_map = deserialize(config["transform_map"])
@@ -102,6 +116,7 @@ def transform_constructor(*args, **kwargs):
 
     def get_config(self) -> dict:
         return {
+            "transform_constructor": get_registered_name(self.transform_constructor),
             "predicate": serialize(self.predicate),
             "include": serialize(self.include),
             "exclude": serialize(self.exclude),

bayesflow/adapters/transforms/standardize.py

Lines changed: 5 additions & 1 deletion
@@ -41,10 +41,14 @@ def __init__(
 
     @classmethod
     def from_config(cls, config: dict, custom_objects=None) -> "Standardize":
+        # Deserialize turns tuples to lists, undo it if necessary
+        deserialized_axis = deserialize(config["axis"], custom_objects)
+        if isinstance(deserialized_axis, list):
+            deserialized_axis = tuple(deserialized_axis)
         return cls(
             mean=deserialize(config["mean"], custom_objects),
             std=deserialize(config["std"], custom_objects),
-            axis=deserialize(config["axis"], custom_objects),
+            axis=deserialized_axis,
             momentum=deserialize(config["momentum"], custom_objects),
         )
 

tests/test_adapters/conftest.py

Lines changed: 17 additions & 3 deletions
@@ -12,7 +12,7 @@ def inverse_transform(x):
 
 @pytest.fixture()
 def custom_objects():
-    return globals() | np.__dict__
+    return dict(forward_transform=forward_transform, inverse_transform=inverse_transform)
 
 
 @pytest.fixture()
@@ -22,15 +22,22 @@ def adapter():
     d = (
         Adapter()
         .to_array()
-        .convert_dtype("float64", "float32")
+        .as_set(["s1", "s2"])
+        .broadcast("t1", to="t2")
+        .as_time_series(["t1", "t2"])
+        .convert_dtype("float64", "float32", exclude="o1")
         .concatenate(["x1", "x2"], into="x")
         .concatenate(["y1", "y2"], into="y")
         .expand_dims(["z1"], axis=2)
         .apply(forward=forward_transform, inverse=inverse_transform)
         # TODO: fix this in keras
         # .apply(include="p1", forward=np.log, inverse=np.exp)
         .constrain("p2", lower=0)
-        .standardize()
+        .standardize(exclude=["t1", "t2", "o1"])
+        .drop("d1")
+        .one_hot("o1", 10)
+        .keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "t1", "t2", "o1"])
+        .rename("o1", "o2")
     )
 
     return d
@@ -46,4 +53,11 @@ def random_data():
         "z1": np.random.standard_normal(size=(32, 2)),
         "p1": np.random.lognormal(size=(32, 2)),
         "p2": np.random.lognormal(size=(32, 2)),
+        "s1": np.random.standard_normal(size=(32, 3, 2)),
+        "s2": np.random.standard_normal(size=(32, 3, 2)),
+        "t1": np.zeros((3, 2)),
+        "t2": np.ones((32, 3, 2)),
+        "d1": np.random.standard_normal(size=(32, 2)),
+        "d2": np.random.standard_normal(size=(32, 2)),
+        "o1": np.random.randint(0, 9, size=(32, 2)),
     }

tests/test_adapters/test_adapters.py

Lines changed: 10 additions & 1 deletion
@@ -10,15 +10,24 @@ def test_cycle_consistency(adapter, random_data):
     deprocessed = adapter(processed, inverse=True)
 
     for key, value in random_data.items():
+        if key in ["d1", "d2"]:
+            # dropped
+            continue
         assert key in deprocessed
         assert np.allclose(value, deprocessed[key])
 
 
-def test_serialize_deserialize(adapter, custom_objects):
+def test_serialize_deserialize(adapter, custom_objects, random_data):
+    processed = adapter(random_data)
     serialized = serialize(adapter)
     deserialized = deserialize(serialized, custom_objects)
     reserialized = serialize(deserialized)
 
     assert reserialized.keys() == serialized.keys()
     for key in reserialized:
         assert reserialized[key] == serialized[key]
+
+    random_data["foo"] = random_data["x1"]
+    deserialized_processed = deserialized(random_data)
+    for key, value in processed.items():
+        assert np.allclose(value, deserialized_processed[key])
