Commit 18a174b

Implements H5 legacy saving for Keras Core (#605)

* Add saved_model_test
* Add extra saved model tests
* Fix formatting
* Add JSON utils for legacy saving
* Implement H5 saving for Keras core
* Change saving API routing
* Fix h5 format basic tests
* Ensure compile reload works with H5 format
* Remove useless comments
* Fix imports and formatting
* Move json_utils out of saved_model folder
* Add test for set_weights in optimizer
* Fix formatting
* Add keras options scope to replace use_legacy_config attribute
* Add options scope for legacy serialization, remove circular deps
* Add comments for legacy serialization code routing
* Move saving/legacy to legacy/saving
* Change keras options scope to use global state attr

1 parent bad4ae4 commit 18a174b

14 files changed: +1903 -334 lines
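
The diff below covers only the new JSON utilities and their tests; the rest of the H5 saving work (API routing, compile reload, options scope) lives in the other files of this commit. As a rough sketch of the end-to-end flow the commit message describes, assuming keras_core mirrors the usual Keras entry points (`keras_core.Sequential`, `keras_core.layers.Dense`, `keras_core.saving.load_model`) and that an ".h5" filename routes to the legacy saver, the intended round trip looks something like this (filenames and layer sizes are illustrative only):

import numpy as np

import keras_core

model = keras_core.Sequential(
    [
        keras_core.layers.Dense(4, activation="relu"),
        keras_core.layers.Dense(1),
    ]
)
model.compile(optimizer="adam", loss="mse")
model(np.zeros((2, 8)))  # build the model so there are weights to save

model.save("legacy_model.h5")  # ".h5" extension is assumed to route to the legacy H5 path

reloaded = keras_core.saving.load_model("legacy_model.h5")
# Per the commit message ("Ensure compile reload works with H5 format"),
# the restored model should still be compiled and directly usable.
reloaded.predict(np.zeros((2, 8)))
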
Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
"""JSON utilities for legacy saving formats (h5 and SavedModel)."""

import collections
import enum
import functools
import json

import numpy as np

from keras_core.legacy.saving import serialization
from keras_core.saving import serialization_lib
from keras_core.utils.module_utils import tensorflow as tf

_EXTENSION_TYPE_SPEC = "_EXTENSION_TYPE_SPEC"


class Encoder(json.JSONEncoder):
    """JSON encoder and decoder that handles TensorShapes and tuples."""

    def default(self, obj):
        """Encodes objects for types that aren't handled by the default
        encoder."""
        if tf.available and isinstance(obj, tf.TensorShape):
            items = obj.as_list() if obj.rank is not None else None
            return {"class_name": "TensorShape", "items": items}
        return get_json_type(obj)

    def encode(self, obj):
        return super().encode(_encode_tuple(obj))


def _encode_tuple(x):
    if isinstance(x, tuple):
        return {
            "class_name": "__tuple__",
            "items": tuple(_encode_tuple(i) for i in x),
        }
    elif isinstance(x, list):
        return [_encode_tuple(i) for i in x]
    elif isinstance(x, dict):
        return {key: _encode_tuple(value) for key, value in x.items()}
    else:
        return x


def decode(json_string):
    return json.loads(json_string, object_hook=_decode_helper)


def decode_and_deserialize(
    json_string, module_objects=None, custom_objects=None
):
    """Decodes the JSON and deserializes any Keras objects found in the dict."""
    return json.loads(
        json_string,
        object_hook=functools.partial(
            _decode_helper,
            deserialize=True,
            module_objects=module_objects,
            custom_objects=custom_objects,
        ),
    )


def _decode_helper(
    obj, deserialize=False, module_objects=None, custom_objects=None
):
    """A decoding helper that is TF-object aware.

    Args:
        obj: A decoded dictionary that may represent an object.
        deserialize: Boolean. When True, deserializes any Keras
            objects found in `obj`. Defaults to `False`.
        module_objects: A dictionary of built-in objects to look the name up
            in. Generally, `module_objects` is provided by midlevel library
            implementers.
        custom_objects: A dictionary of custom objects to look the name up in.
            Generally, `custom_objects` is provided by the end user.

    Returns:
        The decoded object.
    """
    if isinstance(obj, dict) and "class_name" in obj:
        if tf.available:
            if obj["class_name"] == "TensorShape":
                return tf.TensorShape(obj["items"])
            elif obj["class_name"] == "TypeSpec":
                from tensorflow.python.framework import type_spec_registry

                return type_spec_registry.lookup(obj["type_spec"])._deserialize(
                    _decode_helper(obj["serialized"])
                )
            elif obj["class_name"] == "CompositeTensor":
                spec = obj["spec"]
                tensors = []
                for dtype, tensor in obj["tensors"]:
                    tensors.append(
                        tf.constant(tensor, dtype=tf.dtypes.as_dtype(dtype))
                    )
                return tf.nest.pack_sequence_as(
                    _decode_helper(spec), tensors, expand_composites=True
                )

        if obj["class_name"] == "__tuple__":
            return tuple(_decode_helper(i) for i in obj["items"])
        elif obj["class_name"] == "__ellipsis__":
            return Ellipsis
        elif deserialize and "__passive_serialization__" in obj:
            # __passive_serialization__ is added by the JSON encoder when
            # encoding an object that has a `get_config()` method.
            try:
                if (
                    "module" not in obj
                ):  # TODO(nkovela): Add TF SavedModel scope
                    return serialization.deserialize_keras_object(
                        obj,
                        module_objects=module_objects,
                        custom_objects=custom_objects,
                    )
                else:
                    return serialization_lib.deserialize_keras_object(
                        obj,
                        module_objects=module_objects,
                        custom_objects=custom_objects,
                    )
            except ValueError:
                pass
        elif obj["class_name"] == "__bytes__":
            return obj["value"].encode("utf-8")
    return obj


def get_json_type(obj):
    """Serializes any object to a JSON-serializable structure.

    Args:
        obj: the object to serialize

    Returns:
        JSON-serializable structure representing `obj`.

    Raises:
        TypeError: if `obj` cannot be serialized.
    """
    # if obj is a serializable Keras class instance
    # e.g. optimizer, layer
    if hasattr(obj, "get_config"):
        # TODO(nkovela): Replace with legacy serialization
        serialized = serialization.serialize_keras_object(obj)
        serialized["__passive_serialization__"] = True
        return serialized

    # if obj is any numpy type
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return obj.item()

    # misc functions (e.g. loss function)
    if callable(obj):
        return obj.__name__

    # if obj is a python 'type'
    if type(obj).__name__ == type.__name__:
        return obj.__name__

    if tf.available and isinstance(obj, tf.compat.v1.Dimension):
        return obj.value

    if tf.available and isinstance(obj, tf.TensorShape):
        return obj.as_list()

    if tf.available and isinstance(obj, tf.DType):
        return obj.name

    if isinstance(obj, collections.abc.Mapping):
        return dict(obj)

    if obj is Ellipsis:
        return {"class_name": "__ellipsis__"}

    # if isinstance(obj, wrapt.ObjectProxy):
    #     return obj.__wrapped__

    if tf.available and isinstance(obj, tf.TypeSpec):
        from tensorflow.python.framework import type_spec_registry

        try:
            type_spec_name = type_spec_registry.get_name(type(obj))
            return {
                "class_name": "TypeSpec",
                "type_spec": type_spec_name,
                "serialized": obj._serialize(),
            }
        except ValueError:
            raise ValueError(
                f"Unable to serialize {obj} to JSON, because the TypeSpec "
                f"class {type(obj)} has not been registered."
            )

    if tf.available and isinstance(obj, tf.__internal__.CompositeTensor):
        spec = tf.type_spec_from_value(obj)
        tensors = []
        for tensor in tf.nest.flatten(obj, expand_composites=True):
            tensors.append((tensor.dtype.name, tensor.numpy().tolist()))
        return {
            "class_name": "CompositeTensor",
            "spec": get_json_type(spec),
            "tensors": tensors,
        }

    if isinstance(obj, enum.Enum):
        return obj.value

    if isinstance(obj, bytes):
        return {"class_name": "__bytes__", "value": obj.decode("utf-8")}

    raise TypeError(
        f"Unable to serialize {obj} to JSON. Unrecognized type {type(obj)}."
    )
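
A minimal usage sketch of the utilities above, mirroring the tests in the next file; it assumes TensorFlow is installed so the TensorShape branch is exercised:

import tensorflow as tf

from keras_core.legacy.saving import json_utils

metadata = {
    "shape": tf.TensorShape([3, None, 5]),  # encoded as {"class_name": "TensorShape", ...}
    "kernel_size": (3, 3),                  # encoded as {"class_name": "__tuple__", ...}
    "name": b"conv_1",                      # encoded as {"class_name": "__bytes__", ...}
}
string = json_utils.Encoder().encode(metadata)
loaded = json_utils.decode(string)

# The decoder restores the original Python/TF types from the tagged dicts.
assert loaded["shape"].as_list() == [3, None, 5]
assert loaded["kernel_size"] == (3, 3)
assert loaded["name"] == b"conv_1"
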
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
import enum

import pytest

from keras_core import backend
from keras_core import testing
from keras_core.legacy.saving import json_utils

if backend.backend() == "tensorflow":
    import tensorflow as tf


class JsonUtilsTestAllBackends(testing.TestCase):
    def test_encode_decode_tuple(self):
        metadata = {"key1": (3, 5), "key2": [(1, (3, 4)), (1,)]}
        string = json_utils.Encoder().encode(metadata)
        loaded = json_utils.decode(string)

        self.assertEqual(set(loaded.keys()), {"key1", "key2"})
        self.assertAllEqual(loaded["key1"], (3, 5))
        self.assertAllEqual(loaded["key2"], [(1, (3, 4)), (1,)])

    def test_encode_decode_enum(self):
        class Enum(enum.Enum):
            CLASS_A = "a"
            CLASS_B = "b"

        config = {"key": Enum.CLASS_A, "key2": Enum.CLASS_B}
        string = json_utils.Encoder().encode(config)
        loaded = json_utils.decode(string)
        self.assertAllEqual({"key": "a", "key2": "b"}, loaded)

    def test_encode_decode_bytes(self):
        b_string = b"abc"
        json_string = json_utils.Encoder().encode(b_string)
        loaded = json_utils.decode(json_string)
        self.assertAllEqual(b_string, loaded)


@pytest.mark.skipif(
    backend.backend() != "tensorflow",
    reason="These JSON serialization tests are specific to TF components.",
)
class JsonUtilsTestTF(testing.TestCase):
    def test_encode_decode_tensor_shape(self):
        metadata = {
            "key1": tf.TensorShape(None),
            "key2": [tf.TensorShape([None]), tf.TensorShape([3, None, 5])],
        }
        string = json_utils.Encoder().encode(metadata)
        loaded = json_utils.decode(string)

        self.assertEqual(set(loaded.keys()), {"key1", "key2"})
        self.assertEqual(loaded["key1"].rank, None)
        self.assertAllEqual(loaded["key2"][0].as_list(), [None])
        self.assertAllEqual(loaded["key2"][1].as_list(), [3, None, 5])

    def test_encode_decode_type_spec(self):
        spec = tf.TensorSpec((1, 5), tf.float32)
        string = json_utils.Encoder().encode(spec)
        loaded = json_utils.decode(string)
        self.assertEqual(spec, loaded)

        invalid_type_spec = {
            "class_name": "TypeSpec",
            "type_spec": "Invalid Type",
            "serialized": None,
        }
        string = json_utils.Encoder().encode(invalid_type_spec)
        with self.assertRaisesRegexp(
            ValueError, "No TypeSpec has been registered"
        ):
            loaded = json_utils.decode(string)

    def test_encode_decode_ragged_tensor(self):
        x = tf.ragged.constant([[1.0, 2.0], [3.0]])
        string = json_utils.Encoder().encode(x)
        loaded = json_utils.decode(string)
        self.assertAllClose(loaded.values, x.values)

    def test_encode_decode_extension_type_tensor(self):
        class MaskedTensor(tf.experimental.ExtensionType):
            __name__ = "MaskedTensor"
            values: tf.Tensor
            mask: tf.Tensor

        x = MaskedTensor(
            values=[[1, 2, 3], [4, 5, 6]],
            mask=[[True, True, False], [True, False, True]],
        )
        string = json_utils.Encoder().encode(x)
        loaded = json_utils.decode(string)
        self.assertAllClose(loaded.values, x.values)
        self.assertAllClose(loaded.mask, x.mask)
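
For reference, a small sketch of the get_json_type fallback paths that the tests above do not cover (NumPy values, callables, enums, Ellipsis). The expected values follow directly from the function body; it assumes only a working keras_core install:

import enum

import numpy as np

from keras_core.legacy.saving import json_utils


class Mode(enum.Enum):
    TRAIN = "train"


# get_json_type is the fallback used by Encoder.default for plain Python
# and NumPy values; a few representative conversions:
assert json_utils.get_json_type(np.float32(0.5)) == 0.5        # numpy scalar -> Python scalar
assert json_utils.get_json_type(np.array([1, 2])) == [1, 2]    # ndarray -> list
assert json_utils.get_json_type(len) == "len"                  # callable -> its __name__
assert json_utils.get_json_type(Mode.TRAIN) == "train"         # Enum -> its value
assert json_utils.get_json_type(Ellipsis) == {"class_name": "__ellipsis__"}
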