@@ -46,26 +46,27 @@ def assert_array_ndindex(
46
46
assert out [out_idx ] == x [x_idx ], msg
47
47
48
48
49
- @st .composite
50
- def concat_shapes (draw , shape , axis ):
51
- shape = list (shape )
52
- shape [axis ] = draw (st .integers (1 , MAX_SIDE ))
53
- return tuple (shape )
54
-
55
-
56
49
@given (
57
50
dtypes = hh .mutually_promotable_dtypes (None , dtypes = dh .numeric_dtypes ),
58
- kw = hh .kwargs ( axis = st . none () | st . integers ( - MAX_DIMS , MAX_DIMS - 1 ) ),
51
+ base_shape = hh .shapes ( ),
59
52
data = st .data (),
60
53
)
61
- def test_concat (dtypes , kw , data ):
54
+ def test_concat (dtypes , base_shape , data ):
55
+ axis_strat = st .none ()
56
+ ndim = len (base_shape )
57
+ if ndim > 0 :
58
+ axis_strat |= st .integers (- ndim , ndim - 1 )
59
+ kw = data .draw (
60
+ axis_strat .flatmap (lambda a : hh .specified_kwargs (("axis" , a , 0 ))), label = "kw"
61
+ )
62
62
axis = kw .get ("axis" , 0 )
63
63
if axis is None :
64
+ _axis = None
64
65
shape_strat = hh .shapes ()
65
66
else :
66
- _axis = axis if axis >= 0 else abs ( axis ) - 1
67
- shape_strat = shared_shapes ( min_dims = _axis + 1 ). flatmap (
68
- lambda s : concat_shapes ( s , axis )
67
+ _axis = axis if axis >= 0 else len ( base_shape ) + axis
68
+ shape_strat = st . integers ( 0 , MAX_SIDE ). map (
69
+ lambda i : base_shape [: _axis ] + ( i ,) + base_shape [ _axis + 1 :]
69
70
)
70
71
arrays = []
71
72
for i , dtype in enumerate (dtypes , 1 ):
@@ -77,18 +78,17 @@ def test_concat(dtypes, kw, data):
77
78
ph .assert_dtype ("concat" , dtypes , out .dtype )
78
79
79
80
shapes = tuple (x .shape for x in arrays )
80
- axis = kw .get ("axis" , 0 )
81
- if axis is None :
81
+ if _axis is None :
82
82
size = sum (math .prod (s ) for s in shapes )
83
83
shape = (size ,)
84
84
else :
85
85
shape = list (shapes [0 ])
86
86
for other_shape in shapes [1 :]:
87
- shape [axis ] += other_shape [axis ]
87
+ shape [_axis ] += other_shape [_axis ]
88
88
shape = tuple (shape )
89
89
ph .assert_result_shape ("concat" , shapes , out .shape , shape , ** kw )
90
90
91
- if axis is None :
91
+ if _axis is None :
92
92
out_indices = (i for i in range (out .size ))
93
93
for x_num , x in enumerate (arrays , 1 ):
94
94
for x_idx in sh .ndindex (x .shape ):
@@ -291,7 +291,7 @@ def test_roll(x, data):
291
291
else :
292
292
axis_strat = st .none ()
293
293
if x .ndim != 0 :
294
- axis_strat = axis_strat | st .integers (- x .ndim , x .ndim - 1 )
294
+ axis_strat |= st .integers (- x .ndim , x .ndim - 1 )
295
295
kw_strat = hh .kwargs (axis = axis_strat )
296
296
kw = data .draw (kw_strat , label = "kw" )
297
297
0 commit comments