Skip to content

Commit 0a63751

Browse files
committed
Add tests
1 parent c14db02 commit 0a63751

File tree

1 file changed

+178
-0
lines changed

1 file changed

+178
-0
lines changed

tests/test_h5.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import pytest
2+
import tempfile
3+
import numpy as np
4+
import json
5+
import os
6+
from pathlib import Path
7+
from addict import Dict
8+
import h5py
9+
from cadet import H5
10+
from cadet.h5 import recursively_save, recursively_load, convert_from_numpy, recursively_load_dict
11+
12+
13+
@pytest.fixture
14+
def h5_instance():
15+
return H5({
16+
"keyString": "value1",
17+
"keyInt": 42,
18+
"keyArray": np.array([1, 2, 3]),
19+
"keyNone": None,
20+
"keyDict": {
21+
"nestedKeyFloat": 12.345,
22+
"nestedKeyList": [1, 2, 3, 4],
23+
"nestedKeyNone": None,
24+
}
25+
})
26+
27+
28+
@pytest.fixture
29+
def temp_h5_file():
30+
with tempfile.NamedTemporaryFile(delete=False, suffix=".h5") as tmp:
31+
yield tmp.name
32+
os.remove(tmp.name)
33+
34+
35+
@pytest.fixture
36+
def temp_json_file():
37+
with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp:
38+
yield tmp.name
39+
os.remove(tmp.name)
40+
41+
42+
def test_init(h5_instance):
43+
assert isinstance(h5_instance.root, Dict)
44+
assert h5_instance.root.keyString == "value1"
45+
assert h5_instance.root.keyInt == 42
46+
47+
48+
def test_save_and_load_h5(h5_instance, temp_h5_file):
49+
h5_instance.filename = temp_h5_file
50+
h5_instance.save()
51+
52+
new_instance = H5()
53+
new_instance.filename = temp_h5_file
54+
new_instance.load()
55+
56+
assert new_instance.root.keyString == b"value1"
57+
assert new_instance.root.keyInt == 42
58+
assert "keyNone" not in new_instance.root
59+
assert all(new_instance.root.keyDict["nestedKeyList"] == [1, 2, 3, 4])
60+
assert "nestedKeyNone" not in new_instance.root.keyDict
61+
assert np.array_equal(new_instance.root.keyArray, h5_instance.root.keyArray)
62+
63+
64+
def test_save_and_load_json(h5_instance, temp_json_file):
65+
h5_instance.save_json(temp_json_file)
66+
67+
new_instance = H5()
68+
new_instance.load_json(temp_json_file)
69+
70+
assert new_instance.root.keyString == "value1"
71+
assert new_instance.root.keyInt == 42
72+
assert new_instance.root.keyArray == [1, 2, 3]
73+
74+
75+
def test_append_data(h5_instance, temp_h5_file):
76+
h5_instance.filename = temp_h5_file
77+
h5_instance.save()
78+
79+
h5_instance["key4"] = "new_value"
80+
81+
with pytest.raises(KeyError):
82+
# This correctly raises a KeyError because h5_instance still contains
83+
# e.g. keyString and .append would try to over-write keyString
84+
h5_instance.append()
85+
86+
addition_h5_instance = H5()
87+
addition_h5_instance.filename = temp_h5_file
88+
89+
addition_h5_instance["key4"] = "new_value"
90+
addition_h5_instance.append()
91+
92+
new_instance = H5()
93+
new_instance.filename = temp_h5_file
94+
new_instance.load()
95+
96+
assert new_instance.root.key4 == b"new_value"
97+
98+
99+
def test_update(h5_instance):
100+
other_instance = H5({"keyInt": 100, "key4": "added"})
101+
h5_instance.update(other_instance)
102+
103+
assert h5_instance.root.keyInt == 100
104+
assert h5_instance.root.key4 == "added"
105+
106+
107+
def test_recursively_save_and_load(h5_instance, temp_h5_file):
108+
data = Dict({"group": {"dataset": np.array([10, 20, 30])}})
109+
110+
with h5py.File(temp_h5_file, "w") as h5file:
111+
recursively_save(h5file, "/", data, lambda x: x)
112+
113+
with h5py.File(temp_h5_file, "r") as h5file:
114+
loaded_data = recursively_load(h5file, "/", lambda x: x, None)
115+
116+
assert np.array_equal(loaded_data["group"]["dataset"], np.array([10, 20, 30]))
117+
118+
119+
def test_transform_methods():
120+
instance = H5()
121+
data = np.array([1, 2, 3])
122+
123+
transformed = instance.transform(data)
124+
inverse_transformed = instance.inverse_transform(transformed)
125+
126+
assert np.array_equal(inverse_transformed, data)
127+
128+
129+
def test_convert_from_numpy():
130+
data = Dict({"array": np.array([1, 2, 3]), "scalar": np.int32(10)})
131+
converted = convert_from_numpy(data)
132+
133+
assert converted["array"] == [1, 2, 3]
134+
assert converted["scalar"] == 10
135+
136+
137+
def test_recursively_load_dict():
138+
data = {"nested": {"value": np.int32(42), "bytes": b"text"}}
139+
loaded = recursively_load_dict(data)
140+
141+
assert loaded.nested.value == 42
142+
assert loaded.nested.bytes == "text"
143+
144+
145+
def test_get_set_item(h5_instance):
146+
h5_instance["key4"] = "test_value"
147+
assert h5_instance["key4"] == "test_value"
148+
149+
h5_instance["nested/key5"] = 123
150+
assert h5_instance["nested/key5"] == 123
151+
152+
153+
def test_string_representation(h5_instance):
154+
representation = str(h5_instance)
155+
assert "Filename = None" in representation
156+
assert "keyString" in representation
157+
assert "keyInt" in representation
158+
159+
160+
def test_load_nonexistent_file():
161+
instance = H5()
162+
instance.filename = "nonexistent_file.h5"
163+
with pytest.raises(OSError):
164+
instance.load()
165+
166+
167+
# def test_save_without_filename(h5_instance):
168+
# with pytest.raises(ValueError):
169+
# h5_instance.save()
170+
171+
172+
def test_load_json_with_invalid_data(temp_json_file):
173+
invalid_data = "{invalid_json: true}"
174+
Path(temp_json_file).write_text(invalid_data)
175+
176+
instance = H5()
177+
with pytest.raises(json.JSONDecodeError):
178+
instance.load_json(temp_json_file)

0 commit comments

Comments
 (0)