@@ -64,7 +64,7 @@ def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
6464
6565
6666@wraps (xps .arrays )
67- def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
67+ def arrays_no_scalars (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
6868 """xps.arrays() without the crazy large numbers."""
6969 if isinstance (dtype , SearchStrategy ):
7070 return dtype .flatmap (lambda d : arrays (d , * args , elements = elements , ** kwargs ))
@@ -77,6 +77,19 @@ def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
7777 return xps .arrays (dtype , * args , elements = elements , ** kwargs )
7878
7979
80+ def _f (a , flag ):
81+ return a [()] if a .ndim == 0 and flag else a
82+
83+
84+ @wraps (xps .arrays )
85+ def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
86+ """xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars.
87+
88+ Is only relevant for numpy: on all other libraries, array[()] is no-op.
89+ """
90+ return builds (_f , arrays_no_scalars (dtype , * args , elements = elements , ** kwargs ), booleans ())
91+
92+
8093_dtype_categories = [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .real_float_dtypes , dh .complex_dtypes ]
8194_sorted_dtypes = [d for category in _dtype_categories for d in category ]
8295
0 commit comments