diff --git a/bayesflow/data_adapters/transforms/lambda_transform.py b/bayesflow/data_adapters/transforms/lambda_transform.py index 5b1bf2395..a2552069d 100644 --- a/bayesflow/data_adapters/transforms/lambda_transform.py +++ b/bayesflow/data_adapters/transforms/lambda_transform.py @@ -8,6 +8,14 @@ @serializable(package="bayesflow.data_adapters") class LambdaTransform(Transform): + """ + Transforms a parameter using a pair of forward and inverse functions. + + Important note: This class is only serializable if the forward and inverse functions are serializable. + This most likely means you will have to pass the scope that the forward and inverse functions are contained in + to the `custom_objects` argument of the `deserialize` function when deserializing this class. + """ + def __init__(self, parameter_name: str, forward: callable, inverse: callable): super().__init__(parameter_name) @@ -16,11 +24,15 @@ def __init__(self, parameter_name: str, forward: callable, inverse: callable): @classmethod def from_config(cls, config: dict, custom_objects=None) -> "LambdaTransform": - return cls(config["parameter_name"], deserialize(config["forward"]), deserialize(config["inverse"])) + return cls( + deserialize(config["parameter_name"], custom_objects), + deserialize(config["forward"], custom_objects), + deserialize(config["inverse"], custom_objects), + ) def get_config(self) -> dict: return { - "parameter_name": self.parameter_name, + "parameter_name": serialize(self.parameter_name), "forward": serialize(self.forward), "inverse": serialize(self.inverse), } diff --git a/tests/test_data_adapters/conftest.py b/tests/test_data_adapters/conftest.py index ec7365cb5..0cdcb0e93 100644 --- a/tests/test_data_adapters/conftest.py +++ b/tests/test_data_adapters/conftest.py @@ -2,13 +2,32 @@ import pytest +def forward_transform(x): + return x + 1 + + +def inverse_transform(x): + return x - 1 + + +@pytest.fixture() +def custom_objects(): + return globals() | np.__dict__ + + @pytest.fixture() def data_adapter(): from bayesflow.data_adapters import ConcatenateKeysDataAdapter + from bayesflow.data_adapters.transforms import LambdaTransform, Normalize return ConcatenateKeysDataAdapter( x=["x1", "x2"], y=["y1", "y2"], + transforms=[ + Normalize("x1", mean=np.array([0.0]), std=np.array([1.0])), + # use a lambda transform with global functions + LambdaTransform("x2", forward_transform, inverse_transform), + ], ) diff --git a/tests/test_data_adapters/test_data_adapters.py b/tests/test_data_adapters/test_data_adapters.py index 34cb224f0..f5ff155d9 100644 --- a/tests/test_data_adapters/test_data_adapters.py +++ b/tests/test_data_adapters/test_data_adapters.py @@ -14,9 +14,9 @@ def test_cycle_consistency(data_adapter, random_data): assert keras.ops.all(keras.ops.isclose(value, deprocessed[key])) -def test_serialize_deserialize(data_adapter): +def test_serialize_deserialize(data_adapter, custom_objects): serialized = serialize(data_adapter) - deserialized = deserialize(serialized) + deserialized = deserialize(serialized, custom_objects) reserialized = serialize(deserialized) assert reserialized == serialized