@@ -148,23 +148,29 @@ def cross_args(draw, dtype_objects=dh.real_dtypes):
148148 in the drawn axis.
149149
150150 """
151- shape = list ( draw (shapes ()) )
152- size = len (shape )
153- assume (size > 0 )
151+ shape1 , shape2 = draw (two_mutually_broadcastable_shapes )
152+ min_ndim = min ( len (shape1 ), len ( shape2 ) )
153+ assume (min_ndim > 0 )
154154
155- kw = draw (kwargs (axis = integers (- size , size - 1 )))
155+ kw = draw (kwargs (axis = integers (- min_ndim , - 1 )))
156156 axis = kw .get ('axis' , - 1 )
157- shape [axis ] = 3
158- shape = tuple (shape )
157+ if draw (booleans ()):
158+ # Sometimes allow invalid inputs to test it errors
159+ shape1 = list (shape1 )
160+ shape1 [axis ] = 3
161+ shape1 = tuple (shape1 )
162+ shape2 = list (shape2 )
163+ shape2 [axis ] = 3
164+ shape2 = tuple (shape2 )
159165
160166 mutual_dtypes = shared (mutually_promotable_dtypes (dtypes = dtype_objects ))
161167 arrays1 = arrays (
162168 dtype = mutual_dtypes .map (lambda pair : pair [0 ]),
163- shape = shape ,
169+ shape = shape1 ,
164170 )
165171 arrays2 = arrays (
166172 dtype = mutual_dtypes .map (lambda pair : pair [1 ]),
167- shape = shape ,
173+ shape = shape2 ,
168174 )
169175 return draw (arrays1 ), draw (arrays2 ), kw
170176
@@ -176,15 +182,17 @@ def test_cross(x1_x2_kw):
176182 x1 , x2 , kw = x1_x2_kw
177183
178184 axis = kw .get ('axis' , - 1 )
179- err = "test_cross produced invalid input. This indicates a bug in the test suite."
180- assert x1 . shape == x2 . shape , err
181- shape = x1 . shape
182- assert x1 . shape [ axis ] == x2 . shape [ axis ] == 3 , err
185+ if not ( x1 . shape [ axis ] == x2 . shape [ axis ] == 3 ):
186+ ph . raises ( Exception , lambda : xp . cross ( x1 , x2 , ** kw ),
187+ "cross did not raise an exception for invalid shapes" )
188+ return
183189
184190 res = linalg .cross (x1 , x2 , ** kw )
185191
192+ broadcasted_shape = sh .broadcast_shapes (x1 .shape , x2 .shape )
193+
186194 assert res .dtype == dh .result_type (x1 .dtype , x2 .dtype ), "cross() did not return the correct dtype"
187- assert res .shape == shape , "cross() did not return the correct shape"
195+ assert res .shape == broadcasted_shape , "cross() did not return the correct shape"
188196
189197 def exact_cross (a , b ):
190198 assert a .shape == b .shape == (3 ,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
0 commit comments