Skip to content

Commit 298ba5b

Browse files
authored
Merge pull request #310 from ev-br/array_scalars
ENH: generate numpy scalars or 0D arrays
2 parents 11bb686 + 024956a commit 298ba5b

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

array_api_tests/hypothesis_helpers.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
6464

6565

6666
@wraps(xps.arrays)
67-
def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
67+
def arrays_no_scalars(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
6868
"""xps.arrays() without the crazy large numbers."""
6969
if isinstance(dtype, SearchStrategy):
7070
return dtype.flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
@@ -77,6 +77,19 @@ def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
7777
return xps.arrays(dtype, *args, elements=elements, **kwargs)
7878

7979

80+
def _f(a, flag):
81+
return a[()] if a.ndim==0 and flag else a
82+
83+
84+
@wraps(xps.arrays)
85+
def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
86+
"""xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars.
87+
88+
Is only relevant for numpy: on all other libraries, array[()] is no-op.
89+
"""
90+
return builds(_f, arrays_no_scalars(dtype, *args, elements=elements, **kwargs), booleans())
91+
92+
8093
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
8194
_sorted_dtypes = [d for category in _dtype_categories for d in category]
8295

array_api_tests/test_creation_functions.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,8 @@ def scalar_eq(s1: Scalar, s2: Scalar) -> bool:
263263
data=st.data(),
264264
)
265265
def test_asarray_arrays(shape, dtypes, data):
266-
x = data.draw(hh.arrays(dtype=dtypes.input_dtype, shape=shape), label="x")
266+
# generate arrays only since we draw the copy= kwd below (and np.asarray(scalar, copy=False) error out)
267+
x = data.draw(hh.arrays_no_scalars(dtype=dtypes.input_dtype, shape=shape), label="x")
267268
dtypes_strat = st.just(dtypes.input_dtype)
268269
if dtypes.input_dtype == dtypes.result_dtype:
269270
dtypes_strat |= st.none()

0 commit comments

Comments
 (0)