@@ -43,7 +43,7 @@ def paths(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> str:
43
43
return draw (st .just ("/" ) | keys (max_num_nodes = max_num_nodes ))
44
44
45
45
46
- def v3_dtypes () -> st .SearchStrategy [np .dtype [Any ]]:
46
+ def dtypes () -> st .SearchStrategy [np .dtype [Any ]]:
47
47
return (
48
48
npst .boolean_dtypes ()
49
49
| npst .integer_dtypes (endianness = "=" )
@@ -57,18 +57,12 @@ def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
57
57
)
58
58
59
59
60
+ def v3_dtypes () -> st .SearchStrategy [np .dtype [Any ]]:
61
+ return dtypes ()
62
+
63
+
60
64
def v2_dtypes () -> st .SearchStrategy [np .dtype [Any ]]:
61
- return (
62
- npst .boolean_dtypes ()
63
- | npst .integer_dtypes (endianness = "=" )
64
- | npst .unsigned_integer_dtypes (endianness = "=" )
65
- | npst .floating_dtypes (endianness = "=" )
66
- | npst .complex_number_dtypes (endianness = "=" )
67
- | npst .byte_string_dtypes (endianness = "=" )
68
- | npst .unicode_string_dtypes (endianness = "=" )
69
- | npst .datetime64_dtypes (endianness = "=" )
70
- | npst .timedelta64_dtypes (endianness = "=" )
71
- )
65
+ return dtypes ()
72
66
73
67
74
68
def safe_unicode_for_dtype (dtype : np .dtype [np .str_ ]) -> st .SearchStrategy [str ]:
@@ -144,7 +138,7 @@ def array_metadata(
144
138
shape = draw (array_shapes ())
145
139
ndim = len (shape )
146
140
chunk_shape = draw (array_shapes (min_dims = ndim , max_dims = ndim ))
147
- np_dtype = draw (v3_dtypes ())
141
+ np_dtype = draw (dtypes ())
148
142
dtype = get_data_type_from_native_dtype (np_dtype )
149
143
fill_value = draw (npst .from_dtype (np_dtype ))
150
144
if zarr_format == 2 :
@@ -179,14 +173,12 @@ def numpy_arrays(
179
173
* ,
180
174
shapes : st .SearchStrategy [tuple [int , ...]] = array_shapes ,
181
175
dtype : np .dtype [Any ] | None = None ,
182
- zarr_formats : st .SearchStrategy [ZarrFormat ] = zarr_formats ,
183
176
) -> npt .NDArray [Any ]:
184
177
"""
185
178
Generate numpy arrays that can be saved in the provided Zarr format.
186
179
"""
187
- zarr_format = draw (zarr_formats )
188
180
if dtype is None :
189
- dtype = draw (v3_dtypes () if zarr_format == 3 else v2_dtypes ())
181
+ dtype = draw (dtypes ())
190
182
if np .issubdtype (dtype , np .str_ ):
191
183
safe_unicode_strings = safe_unicode_for_dtype (dtype )
192
184
return draw (npst .arrays (dtype = dtype , shape = shapes , elements = safe_unicode_strings ))
@@ -255,17 +247,24 @@ def arrays(
255
247
attrs : st .SearchStrategy = attrs ,
256
248
zarr_formats : st .SearchStrategy = zarr_formats ,
257
249
) -> Array :
258
- store = draw (stores )
259
- path = draw (paths )
260
- name = draw (array_names )
261
- attributes = draw (attrs )
262
- zarr_format = draw (zarr_formats )
250
+ store = draw (stores , label = "store" )
251
+ path = draw (paths , label = "array parent" )
252
+ name = draw (array_names , label = "array name" )
253
+ attributes = draw (attrs , label = "attributes" )
254
+ zarr_format = draw (zarr_formats , label = "zarr format" )
263
255
if arrays is None :
264
- arrays = numpy_arrays (shapes = shapes , zarr_formats = st .just (zarr_format ))
265
- nparray = draw (arrays )
266
- chunk_shape = draw (chunk_shapes (shape = nparray .shape ))
256
+ arrays = numpy_arrays (shapes = shapes )
257
+ nparray = draw (arrays , label = "array data" )
258
+ chunk_shape = draw (chunk_shapes (shape = nparray .shape ), label = "chunk shape" )
259
+ extra_kwargs = {}
267
260
if zarr_format == 3 and all (c > 0 for c in chunk_shape ):
268
- shard_shape = draw (st .none () | shard_shapes (shape = nparray .shape , chunk_shape = chunk_shape ))
261
+ shard_shape = draw (
262
+ st .none () | shard_shapes (shape = nparray .shape , chunk_shape = chunk_shape ),
263
+ label = "shard shape" ,
264
+ )
265
+ extra_kwargs ["dimension_names" ] = draw (
266
+ dimension_names (ndim = nparray .ndim ), label = "dimension names"
267
+ )
269
268
else :
270
269
shard_shape = None
271
270
# test that None works too.
@@ -286,6 +285,7 @@ def arrays(
286
285
attributes = attributes ,
287
286
# compressor=compressor, # FIXME
288
287
fill_value = fill_value ,
288
+ ** extra_kwargs ,
289
289
)
290
290
291
291
assert isinstance (a , Array )
@@ -385,13 +385,19 @@ def orthogonal_indices(
385
385
npindexer = []
386
386
ndim = len (shape )
387
387
for axis , size in enumerate (shape ):
388
- val = draw (
389
- npst .integer_array_indices (
388
+ if size != 0 :
389
+ strategy = npst .integer_array_indices (
390
390
shape = (size ,), result_shape = npst .array_shapes (min_side = 1 , max_side = size , max_dims = 1 )
391
- )
392
- | basic_indices (min_dims = 1 , shape = (size ,), allow_ellipsis = False )
393
- .map (lambda x : (x ,) if not isinstance (x , tuple ) else x ) # bare ints, slices
394
- .filter (bool ) # skip empty tuple
391
+ ) | basic_indices (min_dims = 1 , shape = (size ,), allow_ellipsis = False )
392
+ else :
393
+ strategy = basic_indices (min_dims = 1 , shape = (size ,), allow_ellipsis = False )
394
+
395
+ val = draw (
396
+ strategy
397
+ # bare ints, slices
398
+ .map (lambda x : (x ,) if not isinstance (x , tuple ) else x )
399
+ # skip empty tuple
400
+ .filter (bool )
395
401
)
396
402
(idxr ,) = val
397
403
if isinstance (idxr , int ):
0 commit comments