@@ -48,31 +48,30 @@ def assert_array_ndindex(
48
48
49
49
@given (
50
50
dtypes = hh .mutually_promotable_dtypes (None , dtypes = dh .numeric_dtypes ),
51
- _axis = st . none () | st . integers ( 0 , MAX_DIMS - 1 ),
51
+ base_shape = hh . shapes ( ),
52
52
data = st .data (),
53
53
)
54
- def test_concat (dtypes , _axis , data ):
55
- if _axis is None :
54
+ def test_concat (dtypes , base_shape , data ):
55
+ axis_strat = st .none ()
56
+ ndim = len (base_shape )
57
+ if ndim > 0 :
58
+ axis_strat |= st .integers (- ndim , ndim - 1 )
59
+ kw = data .draw (
60
+ axis_strat .flatmap (lambda a : hh .specified_kwargs (("axis" , a , 0 ))), label = "kw"
61
+ )
62
+ axis = kw .get ("axis" , 0 )
63
+ if axis is None :
64
+ _axis = None
56
65
shape_strat = hh .shapes ()
57
- axis_strat = st .none ()
58
66
else :
59
- base_shape = data .draw (
60
- hh .shapes (min_dims = _axis + 1 ).map (
61
- lambda t : t [:_axis ] + (None ,) + t [_axis + 1 :]
62
- ),
63
- label = "base shape" ,
64
- )
67
+ _axis = axis if axis >= 0 else len (base_shape ) + axis
65
68
shape_strat = st .integers (0 , MAX_SIDE ).map (
66
69
lambda i : base_shape [:_axis ] + (i ,) + base_shape [_axis + 1 :]
67
70
)
68
- axis_strat = st .sampled_from ([_axis , _axis - len (base_shape )])
69
71
arrays = []
70
72
for i , dtype in enumerate (dtypes , 1 ):
71
73
x = data .draw (xps .arrays (dtype = dtype , shape = shape_strat ), label = f"x{ i } " )
72
74
arrays .append (x )
73
- kw = data .draw (
74
- axis_strat .flatmap (lambda a : hh .specified_kwargs (("axis" , a , 0 ))), label = "kw"
75
- )
76
75
77
76
out = xp .concat (arrays , ** kw )
78
77
@@ -292,7 +291,7 @@ def test_roll(x, data):
292
291
else :
293
292
axis_strat = st .none ()
294
293
if x .ndim != 0 :
295
- axis_strat = axis_strat | st .integers (- x .ndim , x .ndim - 1 )
294
+ axis_strat |= st .integers (- x .ndim , x .ndim - 1 )
296
295
kw_strat = hh .kwargs (axis = axis_strat )
297
296
kw = data .draw (kw_strat , label = "kw" )
298
297
0 commit comments