@@ -64,7 +64,7 @@ def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
64
64
65
65
66
66
@wraps (xps .arrays )
67
- def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
67
+ def arrays_no_scalars (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
68
68
"""xps.arrays() without the crazy large numbers."""
69
69
if isinstance (dtype , SearchStrategy ):
70
70
return dtype .flatmap (lambda d : arrays (d , * args , elements = elements , ** kwargs ))
@@ -77,6 +77,19 @@ def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
77
77
return xps .arrays (dtype , * args , elements = elements , ** kwargs )
78
78
79
79
80
+ def _f (a , flag ):
81
+ return a [()] if a .ndim == 0 and flag else a
82
+
83
+
84
+ @wraps (xps .arrays )
85
+ def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
86
+ """xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars.
87
+
88
+ Is only relevant for numpy: on all other libraries, array[()] is no-op.
89
+ """
90
+ return builds (_f , arrays_no_scalars (dtype , * args , elements = elements , ** kwargs ), booleans ())
91
+
92
+
80
93
_dtype_categories = [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .real_float_dtypes , dh .complex_dtypes ]
81
94
_sorted_dtypes = [d for category in _dtype_categories for d in category ]
82
95
0 commit comments