Skip to content

Commit 9820b2d

Browse files
authored
feat: Support npz archives in NumpyDeserializer (#3799)
1 parent 17a4145 commit 9820b2d

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

src/sagemaker/deserializers.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,12 @@ def deserialize(self, stream, content_type):
189189

190190

191191
class NumpyDeserializer(SimpleBaseDeserializer):
192-
"""Deserialize a stream of data in .npy or UTF-8 CSV/JSON format to a numpy array."""
192+
"""Deserialize a stream of data in .npy, .npz or UTF-8 CSV/JSON format to a numpy array.
193+
194+
Note that when using application/x-npz archive format, the result will usually be a
195+
dictionary-like object containing multiple arrays (as per ``numpy.load()``) - instead of a
196+
single array.
197+
"""
193198

194199
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
195200
"""Initialize a ``NumpyDeserializer`` instance.
@@ -223,6 +228,11 @@ def deserialize(self, stream, content_type):
223228
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
224229
if content_type == "application/x-npy":
225230
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
231+
if content_type == "application/x-npz":
232+
try:
233+
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
234+
finally:
235+
stream.close()
226236
finally:
227237
stream.close()
228238

tests/unit/sagemaker/test_deserializers.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,23 @@ def test_numpy_deserializer_from_npy_object_array_with_allow_pickle_false():
164164
numpy_deserializer.deserialize(stream, "application/x-npy")
165165

166166

167+
def test_numpy_deserializer_from_npz(numpy_deserializer):
    """Check that an npz payload deserializes to an ``NpzFile`` with every array intact."""
    expected = {"alpha": np.ones((2, 3)), "beta": np.zeros((3, 2))}

    # Write a compressed archive into memory, then rewind so the
    # deserializer reads it from the first byte.
    payload = io.BytesIO()
    np.savez_compressed(payload, **expected)
    payload.seek(0)

    loaded = numpy_deserializer.deserialize(payload, "application/x-npz")

    # An archive comes back as a dict-like NpzFile rather than a single array.
    assert isinstance(loaded, np.lib.npyio.NpzFile)
    assert set(expected.keys()) == set(loaded.keys())
    for name, array in expected.items():
        assert np.array_equal(array, loaded[name])
182+
183+
167184
@pytest.fixture
168185
def json_deserializer():
169186
return JSONDeserializer()

0 commit comments

Comments
 (0)