Skip to content

Commit

Permalink
adapter: expand tests to include all transforms
Browse files Browse the repository at this point in the history
* also, expand test to ensure consistency of outputs before and after
  serialization
  • Loading branch information
vpratz committed Feb 13, 2025
1 parent 17083e4 commit 6389db6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
20 changes: 17 additions & 3 deletions tests/test_adapters/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def inverse_transform(x):

@pytest.fixture()
def custom_objects():
return globals() | np.__dict__
return dict(forward_transform=forward_transform, inverse_transform=inverse_transform)


@pytest.fixture()
Expand All @@ -22,15 +22,22 @@ def adapter():
d = (
Adapter()
.to_array()
.convert_dtype("float64", "float32")
.as_set(["s1", "s2"])
.broadcast("t1", to="t2")
.as_time_series(["t1", "t2"])
.convert_dtype("float64", "float32", exclude="o1")
.concatenate(["x1", "x2"], into="x")
.concatenate(["y1", "y2"], into="y")
.expand_dims(["z1"], axis=2)
.apply(forward=forward_transform, inverse=inverse_transform)
# TODO: fix this in keras
# .apply(include="p1", forward=np.log, inverse=np.exp)
.constrain("p2", lower=0)
.standardize()
.standardize(exclude=["t1", "t2", "o1"])
.drop("d1")
.one_hot("o1", 10)
.keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "t1", "t2", "o1"])
.rename("o1", "o2")
)

return d
Expand All @@ -46,4 +53,11 @@ def random_data():
"z1": np.random.standard_normal(size=(32, 2)),
"p1": np.random.lognormal(size=(32, 2)),
"p2": np.random.lognormal(size=(32, 2)),
"s1": np.random.standard_normal(size=(32, 3, 2)),
"s2": np.random.standard_normal(size=(32, 3, 2)),
"t1": np.zeros((3, 2)),
"t2": np.ones((32, 3, 2)),
"d1": np.random.standard_normal(size=(32, 2)),
"d2": np.random.standard_normal(size=(32, 2)),
"o1": np.random.randint(0, 9, size=(32, 2)),
}
11 changes: 10 additions & 1 deletion tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,24 @@ def test_cycle_consistency(adapter, random_data):
deprocessed = adapter(processed, inverse=True)

for key, value in random_data.items():
if key in ["d1", "d2"]:
# dropped
continue
assert key in deprocessed
assert np.allclose(value, deprocessed[key])


def test_serialize_deserialize(adapter, custom_objects):
def test_serialize_deserialize(adapter, custom_objects, random_data):
processed = adapter(random_data)
serialized = serialize(adapter)
deserialized = deserialize(serialized, custom_objects)
reserialized = serialize(deserialized)

assert reserialized.keys() == serialized.keys()
for key in reserialized:
assert reserialized[key] == serialized[key]

random_data["foo"] = random_data["x1"]
deserialized_processed = deserialized(random_data)
for key, value in processed.items():
assert np.allclose(value, deserialized_processed[key])

0 comments on commit 6389db6

Please sign in to comment.