Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change serializer #231

Merged
merged 8 commits into from
Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions tests/integration_tests/test_zn_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import json
import os
import pathlib
import shutil
import subprocess

import pytest

from zntrack import zn
from zntrack.core.base import Node


@pytest.fixture
def proj_path(tmp_path):
shutil.copy(__file__, tmp_path)
os.chdir(tmp_path)
subprocess.check_call(["git", "init"])
subprocess.check_call(["dvc", "init"])

return tmp_path


class FirstNode(Node):
outs = zn.outs()

def run(self):
self.outs = 42


class LastNode(Node):
first_node: FirstNode = zn.deps(FirstNode.load())
outs = zn.outs()

def run(self):
self.outs = self.first_node.outs / 2


def test_base_run(proj_path):
FirstNode().write_graph(run=True)
LastNode().write_graph(run=True)

assert LastNode.load().outs == 21


@pytest.fixture()
def zntrack_dict() -> dict:
return {
"LastNode": {
"first_node": {
"_type": "ZnTrackType",
"value": {
"cls": "FirstNode",
"module": "test_zn_deps",
"name": "FirstNode",
},
}
}
}


def test_assert_write_file(proj_path, zntrack_dict):
FirstNode().write_graph()
LastNode().write_graph()

zntrack_dict_loaded = json.loads(pathlib.Path("zntrack.json").read_text())

assert zntrack_dict_loaded == zntrack_dict


def test_assert_read_file(proj_path, zntrack_dict):
pathlib.Path("zntrack.json").write_text(json.dumps(zntrack_dict))

assert isinstance(LastNode.load().first_node, FirstNode)
72 changes: 67 additions & 5 deletions tests/integration_tests/test_zn_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,13 @@ def test_created_files(proj_path):

assert zntrack_dict["SingleNode"]["data_class"] == {
"_type": "zn.method",
"value": {
"module": "test_zn_methods",
"cls": "ExampleMethod",
},
"value": {"module": "test_zn_methods"},
}
assert params_dict["SingleNode"]["data_class"] == {
"param1": 1,
"param2": 2,
"_cls": "ExampleMethod",
}
assert params_dict["SingleNode"]["data_class"] == {"param1": 1, "param2": 2}


class SingleNodeNoParams(Node):
Expand Down Expand Up @@ -162,3 +163,64 @@ def test_write_params_no_kwargs(proj_path):

dvc_dict = yaml.safe_load(pathlib.Path("dvc.yaml").read_text())
assert dvc_dict["stages"]["SingleNodeNoParams"]["params"] == ["SingleNodeNoParams"]


@pytest.fixture()
def zntrack_params_dict() -> (dict, dict):
zntrack_dict = {
"SingleNodeNoParams": {
"data_class": {"_type": "zn.method", "value": {"module": "test_zn_methods"}}
}
}
params_dict = {
"SingleNodeNoParams": {
"data_class": {"_cls": "ExampleMethod", "param1": 1, "param2": 2}
}
}
return zntrack_dict, params_dict


def test_assert_write_files(proj_path, zntrack_params_dict):
"""Test the written files (without mocking pathlibs write_text)"""
SingleNodeNoParams(data_class=ExampleMethod(1, 2)).write_graph()

zntrack_dict = json.loads(pathlib.Path("zntrack.json").read_text())
params_dict = yaml.safe_load(pathlib.Path("params.yaml").read_text())

assert zntrack_dict == zntrack_params_dict[0]
assert params_dict == zntrack_params_dict[1]


def test_assert_read_files(proj_path, zntrack_params_dict):
"""Test the written files (without mocking pathlibs write_text)"""
zntrack_dict, params_dict = zntrack_params_dict

pathlib.Path("zntrack.json").write_text(json.dumps(zntrack_dict))
pathlib.Path("params.yaml").write_text(yaml.safe_dump(params_dict))

node = SingleNodeNoParams.load()
assert node.data_class.param1 == 1
assert node.data_class.param2 == 2


def test_assert_read_files_old1(proj_path):
"""Test the written files (without mocking pathlibs write_text)

Test for versions before https://github.com/zincware/ZnTrack/pull/231
"""
zntrack_dict = {
"SingleNodeNoParams": {
"data_class": {
"_type": "zn.method",
"value": {"cls": "ExampleMethod", "module": "test_zn_methods"},
}
}
}
params_dict = {"SingleNodeNoParams": {"data_class": {"param1": 1, "param2": 2}}}

pathlib.Path("zntrack.json").write_text(json.dumps(zntrack_dict))
pathlib.Path("params.yaml").write_text(yaml.safe_dump(params_dict))

node = SingleNodeNoParams.load()
assert node.data_class.param1 == 1
assert node.data_class.param2 == 2
78 changes: 78 additions & 0 deletions tests/unit_tests/zn/test_zn_options.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import dataclasses
import json
import pathlib

import znjson

from zntrack import zn
from zntrack.zn.split_option import combine_values, split_value


class ExampleClass:
Expand Down Expand Up @@ -30,3 +37,74 @@ def test_zn_plots():
# test save and load if there is nothing to save or load
assert ExamplePlots.plots.save(example) is None
assert ExamplePlots.plots.load(example) is None


@dataclasses.dataclass
class ExampleDataClass:
a: int = 5
b: int = 7

# make it a zn.Method
znjson_zn_method = True

def __eq__(self, other):
return (other.a == self.a) and (other.b == self.b)


def test_split_value():
serialized_value = json.loads(json.dumps(ExampleDataClass(), cls=znjson.ZnEncoder))

params_data, zntrack_data = split_value(serialized_value)
assert zntrack_data == {"_type": "zn.method", "value": {"module": "test_zn_options"}}
assert params_data == {"_cls": "ExampleDataClass", "a": 5, "b": 7}

# and now test the same thing but serialize a list
serialized_value = json.loads(json.dumps([ExampleDataClass()], cls=znjson.ZnEncoder))
params_data, zntrack_data = split_value(serialized_value)
assert zntrack_data == [
{"_type": "zn.method", "value": {"module": "test_zn_options"}}
]
assert params_data == ({"_cls": "ExampleDataClass", "a": 5, "b": 7},)


def test_combine_values():
zntrack_data = {"_type": "zn.method", "value": {"module": "test_zn_options"}}
params_data = {"_cls": "ExampleDataClass", "a": 5, "b": 7}

assert combine_values(zntrack_data, params_data) == ExampleDataClass()

# try older data structure
zntrack_data = {
"_type": "zn.method",
"value": {
"module": "test_zn_options",
"cls": "ExampleDataClass",
},
}
params_data = {"a": 5, "b": 7}
assert combine_values(zntrack_data, params_data) == ExampleDataClass()

# try older data structure
zntrack_data = {
"_type": "zn.method",
"value": {
"module": "test_zn_options",
"name": "ExampleDataClass",
},
}
params_data = {"a": 5, "b": 7}
assert combine_values(zntrack_data, params_data) == ExampleDataClass()


def test_split_value_path():
path = pathlib.Path("my_path")
serialized_value = json.loads(json.dumps(path, cls=znjson.ZnEncoder))

params_data, zntrack_data = split_value(serialized_value)

assert params_data == "my_path"
assert zntrack_data == {"_type": "pathlib.Path"}

new_path = combine_values(zntrack_data, params_data)
# TODO change order to be consistent with split_values
assert new_path == path
2 changes: 2 additions & 0 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
cwd_temp_dir,
decode_dict,
deprecated,
encode_dict,
get_python_interpreter,
module_handler,
module_to_path,
Expand All @@ -27,6 +28,7 @@
"config",
"cwd_temp_dir",
"decode_dict",
"encode_dict",
"module_handler",
"update_nb_name",
"module_to_path",
Expand Down
5 changes: 5 additions & 0 deletions zntrack/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def decode_dict(value):
return json.loads(json.dumps(value), cls=znjson.ZnDecoder)


def encode_dict(value) -> dict:
"""Encode value into a dict serialized with ZnJson"""
return json.loads(json.dumps(value, cls=znjson.ZnEncoder))


def get_auto_init(fields: typing.List[str]):
"""Automatically create a __init__ based on fields
Parameters
Expand Down
Loading