Skip to content

Commit 416d2c2

Browse files
committed
Update stateful/property tests.
Add actions to (1) overwrite data with orthogonal indexing (oindex) and (2) read and compare a full array.
1 parent f68bf06 commit 416d2c2

File tree

3 files changed

+70
-41
lines changed

3 files changed

+70
-41
lines changed

src/zarr/testing/stateful.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
from zarr.testing.strategies import (
2525
basic_indices,
2626
chunk_paths,
27+
dimension_names,
2728
key_ranges,
2829
node_names,
2930
np_array_and_chunks,
30-
numpy_arrays,
31+
orthogonal_indices,
3132
)
3233
from zarr.testing.strategies import keys as zarr_keys
3334

@@ -90,11 +91,7 @@ def add_group(self, name: str, data: DataObject) -> None:
9091
zarr.group(store=self.store, path=path)
9192
zarr.group(store=self.model, path=path)
9293

93-
@rule(
94-
data=st.data(),
95-
name=node_names,
96-
array_and_chunks=np_array_and_chunks(arrays=numpy_arrays(zarr_formats=st.just(3))),
97-
)
94+
@rule(data=st.data(), name=node_names, array_and_chunks=np_array_and_chunks())
9895
def add_array(
9996
self,
10097
data: DataObject,
@@ -122,6 +119,10 @@ def add_array(
122119
path=path,
123120
store=store,
124121
fill_value=fill_value,
122+
zarr_format=3,
123+
dimension_names=data.draw(
124+
dimension_names(ndim=array.ndim), label="dimension names"
125+
),
125126
# Chose bytes codec to avoid wasting time compressing the data being written
126127
codecs=[BytesCodec()],
127128
)
@@ -192,6 +193,14 @@ def delete_chunk(self, data: DataObject) -> None:
192193
self._sync(self.model.delete(path))
193194
self._sync(self.store.delete(path))
194195

196+
@precondition(lambda self: bool(self.all_arrays))
197+
@rule(data=st.data())
198+
def check_array(self, data: DataObject) -> None:
199+
path = data.draw(st.sampled_from(sorted(self.all_arrays)))
200+
actual = zarr.open_array(self.store, path=path)[:]
201+
expected = zarr.open_array(self.model, path=path)[:]
202+
np.testing.assert_equal(actual, expected)
203+
195204
@precondition(lambda self: bool(self.all_arrays))
196205
@rule(data=st.data())
197206
def overwrite_array_basic_indexing(self, data: DataObject) -> None:
@@ -206,6 +215,20 @@ def overwrite_array_basic_indexing(self, data: DataObject) -> None:
206215
model_array[slicer] = new_data
207216
store_array[slicer] = new_data
208217

218+
@precondition(lambda self: bool(self.all_arrays))
219+
@rule(data=st.data())
220+
def overwrite_array_orthogonal_indexing(self, data: DataObject) -> None:
221+
array = data.draw(st.sampled_from(sorted(self.all_arrays)))
222+
model_array = zarr.open_array(path=array, store=self.model)
223+
store_array = zarr.open_array(path=array, store=self.store)
224+
indexer, _ = data.draw(orthogonal_indices(shape=model_array.shape))
225+
note(f"overwriting array orthogonal {indexer=}")
226+
new_data = data.draw(
227+
npst.arrays(shape=model_array.oindex[indexer].shape, dtype=model_array.dtype) # type: ignore[union-attr]
228+
)
229+
model_array.oindex[indexer] = new_data
230+
store_array.oindex[indexer] = new_data
231+
209232
@precondition(lambda self: bool(self.all_arrays))
210233
@rule(data=st.data())
211234
def resize_array(self, data: DataObject) -> None:

src/zarr/testing/strategies.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def paths(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> str:
4343
return draw(st.just("/") | keys(max_num_nodes=max_num_nodes))
4444

4545

46-
def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
46+
def dtypes() -> st.SearchStrategy[np.dtype[Any]]:
4747
return (
4848
npst.boolean_dtypes()
4949
| npst.integer_dtypes(endianness="=")
@@ -57,18 +57,12 @@ def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
5757
)
5858

5959

60+
def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
61+
return dtypes()
62+
63+
6064
def v2_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
61-
return (
62-
npst.boolean_dtypes()
63-
| npst.integer_dtypes(endianness="=")
64-
| npst.unsigned_integer_dtypes(endianness="=")
65-
| npst.floating_dtypes(endianness="=")
66-
| npst.complex_number_dtypes(endianness="=")
67-
| npst.byte_string_dtypes(endianness="=")
68-
| npst.unicode_string_dtypes(endianness="=")
69-
| npst.datetime64_dtypes(endianness="=")
70-
| npst.timedelta64_dtypes(endianness="=")
71-
)
65+
return dtypes()
7266

7367

7468
def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]:
@@ -144,7 +138,7 @@ def array_metadata(
144138
shape = draw(array_shapes())
145139
ndim = len(shape)
146140
chunk_shape = draw(array_shapes(min_dims=ndim, max_dims=ndim))
147-
np_dtype = draw(v3_dtypes())
141+
np_dtype = draw(dtypes())
148142
dtype = get_data_type_from_native_dtype(np_dtype)
149143
fill_value = draw(npst.from_dtype(np_dtype))
150144
if zarr_format == 2:
@@ -179,14 +173,12 @@ def numpy_arrays(
179173
*,
180174
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
181175
dtype: np.dtype[Any] | None = None,
182-
zarr_formats: st.SearchStrategy[ZarrFormat] = zarr_formats,
183176
) -> npt.NDArray[Any]:
184177
"""
185178
Generate numpy arrays whose dtypes can be saved in either Zarr format (v2 or v3).
186179
"""
187-
zarr_format = draw(zarr_formats)
188180
if dtype is None:
189-
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
181+
dtype = draw(dtypes())
190182
if np.issubdtype(dtype, np.str_):
191183
safe_unicode_strings = safe_unicode_for_dtype(dtype)
192184
return draw(npst.arrays(dtype=dtype, shape=shapes, elements=safe_unicode_strings))
@@ -255,17 +247,24 @@ def arrays(
255247
attrs: st.SearchStrategy = attrs,
256248
zarr_formats: st.SearchStrategy = zarr_formats,
257249
) -> Array:
258-
store = draw(stores)
259-
path = draw(paths)
260-
name = draw(array_names)
261-
attributes = draw(attrs)
262-
zarr_format = draw(zarr_formats)
250+
store = draw(stores, label="store")
251+
path = draw(paths, label="array parent")
252+
name = draw(array_names, label="array name")
253+
attributes = draw(attrs, label="attributes")
254+
zarr_format = draw(zarr_formats, label="zarr format")
263255
if arrays is None:
264-
arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format))
265-
nparray = draw(arrays)
266-
chunk_shape = draw(chunk_shapes(shape=nparray.shape))
256+
arrays = numpy_arrays(shapes=shapes)
257+
nparray = draw(arrays, label="array data")
258+
chunk_shape = draw(chunk_shapes(shape=nparray.shape), label="chunk shape")
259+
extra_kwargs = {}
267260
if zarr_format == 3 and all(c > 0 for c in chunk_shape):
268-
shard_shape = draw(st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape))
261+
shard_shape = draw(
262+
st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape),
263+
label="shard shape",
264+
)
265+
extra_kwargs["dimension_names"] = draw(
266+
dimension_names(ndim=nparray.ndim), label="dimension names"
267+
)
269268
else:
270269
shard_shape = None
271270
# test that None works too.
@@ -286,6 +285,7 @@ def arrays(
286285
attributes=attributes,
287286
# compressor=compressor, # FIXME
288287
fill_value=fill_value,
288+
**extra_kwargs,
289289
)
290290

291291
assert isinstance(a, Array)
@@ -385,13 +385,19 @@ def orthogonal_indices(
385385
npindexer = []
386386
ndim = len(shape)
387387
for axis, size in enumerate(shape):
388-
val = draw(
389-
npst.integer_array_indices(
388+
if size != 0:
389+
strategy = npst.integer_array_indices(
390390
shape=(size,), result_shape=npst.array_shapes(min_side=1, max_side=size, max_dims=1)
391-
)
392-
| basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)
393-
.map(lambda x: (x,) if not isinstance(x, tuple) else x) # bare ints, slices
394-
.filter(bool) # skip empty tuple
391+
) | basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)
392+
else:
393+
strategy = basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)
394+
395+
val = draw(
396+
strategy
397+
# bare ints, slices
398+
.map(lambda x: (x,) if not isinstance(x, tuple) else x)
399+
# skip empty tuple
400+
.filter(bool)
395401
)
396402
(idxr,) = val
397403
if isinstance(idxr, int):

tests/test_properties.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ def deep_equal(a: Any, b: Any) -> bool:
7676

7777

7878
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
79-
@given(data=st.data(), zarr_format=zarr_formats)
80-
def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None:
81-
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
82-
zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format)))
79+
@given(data=st.data())
80+
def test_array_roundtrip(data: st.DataObject) -> None:
81+
nparray = data.draw(numpy_arrays())
82+
zarray = data.draw(arrays(arrays=st.just(nparray)))
8383
assert_array_equal(nparray, zarray[:])
8484

8585

0 commit comments

Comments
 (0)