Skip to content

Commit

Permalink
improve tests for transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Sep 13, 2024
1 parent 175f89d commit 91d9b8b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
16 changes: 14 additions & 2 deletions bayesflow/data_adapters/transforms/lambda_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@

@serializable(package="bayesflow.data_adapters")
class LambdaTransform(Transform):
"""
Transforms a parameter using a pair of forward and inverse functions.
Important note: This class is only serializable if the forward and inverse functions are serializable.
This most likely means you will have to pass the scope that the forward and inverse functions are contained in
to the `custom_objects` argument of the `deserialize` function when deserializing this class.
"""

def __init__(self, parameter_name: str, forward: callable, inverse: callable):
super().__init__(parameter_name)

Expand All @@ -16,11 +24,15 @@ def __init__(self, parameter_name: str, forward: callable, inverse: callable):

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "LambdaTransform":
return cls(config["parameter_name"], deserialize(config["forward"]), deserialize(config["inverse"]))
return cls(
deserialize(config["parameter_name"], custom_objects),
deserialize(config["forward"], custom_objects),
deserialize(config["inverse"], custom_objects),
)

def get_config(self) -> dict:
return {
"parameter_name": self.parameter_name,
"parameter_name": serialize(self.parameter_name),
"forward": serialize(self.forward),
"inverse": serialize(self.inverse),
}
19 changes: 19 additions & 0 deletions tests/test_data_adapters/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,32 @@
import pytest


def forward_transform(x):
return x + 1


def inverse_transform(x):
return x - 1


@pytest.fixture()
def custom_objects():
return globals() | np.__dict__


@pytest.fixture()
def data_adapter():
from bayesflow.data_adapters import ConcatenateKeysDataAdapter
from bayesflow.data_adapters.transforms import LambdaTransform, Normalize

return ConcatenateKeysDataAdapter(
x=["x1", "x2"],
y=["y1", "y2"],
transforms=[
Normalize("x1", mean=np.array([0.0]), std=np.array([1.0])),
# use a lambda transform with global functions
LambdaTransform("x2", forward_transform, inverse_transform),
],
)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_data_adapters/test_data_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ def test_cycle_consistency(data_adapter, random_data):
assert keras.ops.all(keras.ops.isclose(value, deprocessed[key]))


def test_serialize_deserialize(data_adapter):
def test_serialize_deserialize(data_adapter, custom_objects):
serialized = serialize(data_adapter)
deserialized = deserialize(serialized)
deserialized = deserialize(serialized, custom_objects)
reserialized = serialize(deserialized)

assert reserialized == serialized

0 comments on commit 91d9b8b

Please sign in to comment.