Skip to content

Commit cc24f0d

Browse files
committed
Round trip serialization for array metadata v2/v3
1 parent 370eb8b commit cc24f0d

File tree

4 files changed

+27 additions, -159 deletions (lines changed)

src/zarr/core/metadata/v2.py

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,7 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
170170
if dtype.kind in "SV":
171171
fill_value_encoded = _data.get("fill_value")
172172
if fill_value_encoded is not None:
173-
fill_value = base64.standard_b64decode(fill_value_encoded)
173+
fill_value: Any = base64.standard_b64decode(fill_value_encoded)
174174
_data["fill_value"] = fill_value
175175
else:
176176
fill_value = _data.get("fill_value")
@@ -180,13 +180,11 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
180180
_data["fill_value"] = np.array("NaT", dtype=dtype)[()]
181181
else:
182182
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
183-
elif dtype.kind == "c" and isinstance(fill_value, list):
184-
if len(fill_value) == 2:
185-
val = complex(float(fill_value[0]), float(fill_value[1]))
186-
_data["fill_value"] = np.array(val, dtype=dtype)[()]
187-
elif dtype.kind in "f" and isinstance(fill_value, str):
188-
if fill_value in {"NaN", "Infinity", "-Infinity"}:
189-
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
183+
elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2:
184+
val = complex(float(fill_value[0]), float(fill_value[1]))
185+
_data["fill_value"] = np.array(val, dtype=dtype)[()]
186+
elif dtype.kind in "f" and fill_value in {"NaN", "Infinity", "-Infinity"}:
187+
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
190188
# zarr v2 allowed arbitrary keys in the metadata.
191189
# Filter the keys to only those expected by the constructor.
192190
expected = {x.name for x in fields(cls)}
@@ -196,21 +194,22 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
196194
return cls(**_data)
197195

198196
def to_dict(self) -> dict[str, JSON]:
199-
def _sanitize_fill_value(fv: Any):
197+
def _sanitize_fill_value(fv: Any) -> JSON:
200198
if fv is None:
201199
return fv
202200
elif isinstance(fv, np.datetime64):
203201
if np.isnat(fv):
204202
return "NaT"
205203
return np.datetime_as_string(fv)
206204
elif isinstance(fv, numbers.Real):
207-
if np.isnan(fv):
205+
float_fv = float(fv)
206+
if np.isnan(float_fv):
208207
fv = "NaN"
209-
elif np.isinf(fv):
210-
fv = "Infinity" if fv > 0 else "-Infinity"
208+
elif np.isinf(float_fv):
209+
fv = "Infinity" if float_fv > 0 else "-Infinity"
211210
elif isinstance(fv, numbers.Complex):
212211
fv = [_sanitize_fill_value(fv.real), _sanitize_fill_value(fv.imag)]
213-
return fv
212+
return cast(JSON, fv)
214213

215214
zarray_dict = super().to_dict()
216215

src/zarr/testing/stateful.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ def add_group(self, name: str, data: DataObject) -> None:
8585
@rule(
8686
data=st.data(),
8787
name=node_names,
88-
array_and_chunks=np_array_and_chunks(nparrays=numpy_arrays(zarr_formats=st.just(3))),
88+
array_and_chunks=np_array_and_chunks(arrays=numpy_arrays(zarr_formats=st.just(3))),
8989
)
9090
def add_array(
9191
self,

src/zarr/testing/strategies.py

Lines changed: 0 additions & 134 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33

44
import hypothesis.extra.numpy as npst
55
import hypothesis.strategies as st
6-
import numcodecs
76
import numpy as np
87
from hypothesis import assume, given, settings # noqa: F401
98
from hypothesis.strategies import SearchStrategy
@@ -345,136 +344,3 @@ def make_request(start: int, length: int) -> RangeByteRequest:
345344
)
346345
key_tuple = st.tuples(keys, byte_ranges)
347346
return st.lists(key_tuple, min_size=1, max_size=10)
348-
349-
350-
def simple_text():
351-
"""A strategy for generating simple text strings."""
352-
return st.text(st.characters(min_codepoint=32, max_codepoint=126), min_size=1, max_size=10)
353-
354-
355-
def simple_attrs():
356-
"""A strategy for generating simple attribute dictionaries."""
357-
return st.dictionaries(
358-
simple_text(),
359-
st.one_of(
360-
st.integers(),
361-
st.floats(allow_nan=False, allow_infinity=False),
362-
st.booleans(),
363-
simple_text(),
364-
),
365-
)
366-
367-
368-
def array_shapes(min_dims=1, max_dims=3, max_len=100):
369-
"""A strategy for generating array shapes."""
370-
return st.lists(
371-
st.integers(min_value=1, max_value=max_len), min_size=min_dims, max_size=max_dims
372-
)
373-
374-
375-
# def zarr_compressors():
376-
# """A strategy for generating Zarr compressors."""
377-
# return st.sampled_from([None, Blosc(), GZip(), Zstd(), LZ4()])
378-
379-
380-
# def zarr_codecs():
381-
# """A strategy for generating Zarr codecs."""
382-
# return st.sampled_from([BytesCodec(), Blosc(), GZip(), Zstd(), LZ4()])
383-
384-
385-
def zarr_filters():
386-
"""A strategy for generating Zarr filters."""
387-
return st.lists(
388-
st.just(numcodecs.Delta(dtype="i4")), min_size=0, max_size=2
389-
) # Example filter, expand as needed
390-
391-
392-
def zarr_storage_transformers():
393-
"""A strategy for generating Zarr storage transformers."""
394-
return st.lists(
395-
st.dictionaries(
396-
simple_text(), st.one_of(st.integers(), st.floats(), st.booleans(), simple_text())
397-
),
398-
min_size=0,
399-
max_size=2,
400-
)
401-
402-
403-
@st.composite
404-
def array_metadata_v2(draw: st.DrawFn) -> ArrayV2Metadata:
405-
"""Generates valid ArrayV2Metadata objects for property-based testing."""
406-
dims = draw(st.integers(min_value=1, max_value=3)) # Limit dimensions for complexity
407-
shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100)))
408-
max_chunk_len = max(shape) if shape else 100
409-
chunks = tuple(
410-
draw(
411-
st.lists(
412-
st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims
413-
)
414-
)
415-
)
416-
417-
# Validate shape and chunks relationship
418-
assume(all(c <= s for s, c in zip(shape, chunks, strict=False))) # Chunk size must be <= shape
419-
420-
dtype = draw(v2_dtypes())
421-
fill_value = draw(st.one_of([st.none(), npst.from_dtype(dtype)]))
422-
order = draw(st.sampled_from(["C", "F"]))
423-
dimension_separator = draw(st.sampled_from([".", "/"]))
424-
# compressor = draw(zarr_compressors())
425-
filters = tuple(draw(zarr_filters())) if draw(st.booleans()) else None
426-
attributes = draw(simple_attrs())
427-
428-
# Construct the metadata object. Type hints are crucial here for correctness.
429-
return ArrayV2Metadata(
430-
shape=shape,
431-
dtype=dtype,
432-
chunks=chunks,
433-
fill_value=fill_value,
434-
order=order,
435-
dimension_separator=dimension_separator,
436-
# compressor=compressor,
437-
filters=filters,
438-
attributes=attributes,
439-
)
440-
441-
442-
@st.composite
443-
def array_metadata_v3(draw: st.DrawFn) -> ArrayV3Metadata:
444-
"""Generates valid ArrayV3Metadata objects for property-based testing."""
445-
dims = draw(st.integers(min_value=1, max_value=3))
446-
shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100)))
447-
max_chunk_len = max(shape) if shape else 100
448-
chunks = tuple(
449-
draw(
450-
st.lists(
451-
st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims
452-
)
453-
)
454-
)
455-
assume(all(c <= s for s, c in zip(shape, chunks, strict=False)))
456-
457-
dtype = draw(v3_dtypes())
458-
fill_value = draw(npst.from_dtype(dtype))
459-
chunk_grid = RegularChunkGrid(chunks) # Ensure chunks is passed as tuple.
460-
chunk_key_encoding = DefaultChunkKeyEncoding(separator="/") # Or st.sampled_from(["/", "."])
461-
# codecs = tuple(draw(st.lists(zarr_codecs(), min_size=0, max_size=3)))
462-
attributes = draw(simple_attrs())
463-
dimension_names = (
464-
tuple(draw(st.lists(st.one_of(st.none(), simple_text()), min_size=dims, max_size=dims)))
465-
if draw(st.booleans())
466-
else None
467-
)
468-
storage_transformers = tuple(draw(zarr_storage_transformers()))
469-
470-
return ArrayV3Metadata(
471-
shape=shape,
472-
data_type=dtype,
473-
chunk_grid=chunk_grid,
474-
chunk_key_encoding=chunk_key_encoding,
475-
fill_value=fill_value,
476-
# codecs=codecs,
477-
attributes=attributes,
478-
dimension_names=dimension_names,
479-
storage_transformers=storage_transformers,
480-
)

tests/test_properties.py

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -14,12 +14,11 @@
1414
from hypothesis import assume, given
1515

1616
from zarr.abc.store import Store
17-
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON
17+
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON
1818
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
1919
from zarr.core.sync import sync
2020
from zarr.testing.strategies import (
2121
array_metadata,
22-
array_metadata_v2,
2322
arrays,
2423
basic_indices,
2524
numpy_arrays,
@@ -84,7 +83,7 @@ def deep_equal(a, b):
8483

8584

8685
@given(data=st.data(), zarr_format=zarr_formats)
87-
def test_roundtrip(data: st.DataObject, zarr_format: int) -> None:
86+
def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None:
8887
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
8988
zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format)))
9089
assert_array_equal(nparray, zarray[:])
@@ -197,16 +196,20 @@ async def test_roundtrip_array_metadata(
197196
# assert_array_equal(nparray, zarray[:])
198197

199198

200-
@given(array_metadata_v2())
201-
def test_v2meta_roundtrip(metadata):
199+
@given(data=st.data(), zarr_format=zarr_formats)
200+
def test_meta_roundtrip(data: st.DataObject, zarr_format: int) -> None:
201+
metadata = data.draw(array_metadata(zarr_formats=st.just(zarr_format)))
202202
buffer_dict = metadata.to_buffer_dict(prototype=default_buffer_prototype())
203-
zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode())
204-
zattrs_dict = json.loads(buffer_dict[ZATTRS_JSON].to_bytes().decode())
205203

206-
# zattrs and zarray are separate in v2, we have to add attributes back prior to `from_dict`
207-
zarray_dict["attributes"] = zattrs_dict
208-
209-
metadata_roundtripped = ArrayV2Metadata.from_dict(zarray_dict)
204+
if zarr_format == 2:
205+
zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode())
206+
zattrs_dict = json.loads(buffer_dict[ZATTRS_JSON].to_bytes().decode())
207+
# zattrs and zarray are separate in v2, we have to add attributes back prior to `from_dict`
208+
zarray_dict["attributes"] = zattrs_dict
209+
metadata_roundtripped = ArrayV2Metadata.from_dict(zarray_dict)
210+
else:
211+
zarray_dict = json.loads(buffer_dict[ZARR_JSON].to_bytes().decode())
212+
metadata_roundtripped = ArrayV3Metadata.from_dict(zarray_dict)
210213

211214
# Convert both metadata instances to dictionaries.
212215
orig = dataclasses.asdict(metadata)

0 commit comments

Comments
 (0)