Skip to content

Commit 2b70419

Browse files
committed
Better performance for test_concat
By generating base shape that dictate the possible axis values
1 parent 8f8da4a commit 2b70419

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,31 +48,30 @@ def assert_array_ndindex(
4848

4949
@given(
5050
dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
51-
_axis=st.none() | st.integers(0, MAX_DIMS - 1),
51+
base_shape=hh.shapes(),
5252
data=st.data(),
5353
)
54-
def test_concat(dtypes, _axis, data):
55-
if _axis is None:
54+
def test_concat(dtypes, base_shape, data):
55+
axis_strat = st.none()
56+
ndim = len(base_shape)
57+
if ndim > 0:
58+
axis_strat |= st.integers(-ndim, ndim - 1)
59+
kw = data.draw(
60+
axis_strat.flatmap(lambda a: hh.specified_kwargs(("axis", a, 0))), label="kw"
61+
)
62+
axis = kw.get("axis", 0)
63+
if axis is None:
64+
_axis = None
5665
shape_strat = hh.shapes()
57-
axis_strat = st.none()
5866
else:
59-
base_shape = data.draw(
60-
hh.shapes(min_dims=_axis + 1).map(
61-
lambda t: t[:_axis] + (None,) + t[_axis + 1 :]
62-
),
63-
label="base shape",
64-
)
67+
_axis = axis if axis >= 0 else len(base_shape) + axis
6568
shape_strat = st.integers(0, MAX_SIDE).map(
6669
lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :]
6770
)
68-
axis_strat = st.sampled_from([_axis, _axis - len(base_shape)])
6971
arrays = []
7072
for i, dtype in enumerate(dtypes, 1):
7173
x = data.draw(xps.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}")
7274
arrays.append(x)
73-
kw = data.draw(
74-
axis_strat.flatmap(lambda a: hh.specified_kwargs(("axis", a, 0))), label="kw"
75-
)
7675

7776
out = xp.concat(arrays, **kw)
7877

@@ -292,7 +291,7 @@ def test_roll(x, data):
292291
else:
293292
axis_strat = st.none()
294293
if x.ndim != 0:
295-
axis_strat = axis_strat | st.integers(-x.ndim, x.ndim - 1)
294+
axis_strat |= st.integers(-x.ndim, x.ndim - 1)
296295
kw_strat = hh.kwargs(axis=axis_strat)
297296
kw = data.draw(kw_strat, label="kw")
298297

0 commit comments

Comments
 (0)