@@ -151,8 +151,6 @@ def test_cuda_negxplus1(self):
151151 self ._negxplus1_cuda (TensorProto .FLOAT16 )
152152
153153 def _addmul_shared_input_cuda (self , itype , op_type , shapea = (3 , 2 , 3 ), shapeb = (3 , 2 , 3 ), shapec = (3 , 2 , 3 )):
154- from onnx_extended .ortops .optim .cuda import get_ort_ext_libs
155-
156154 model1 = helper .make_model (
157155 helper .make_graph (
158156 [
@@ -181,7 +179,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
181179 f"{ op_type } SharedInput" ,
182180 ["X" , "Y" , "Z" ],
183181 ["XY" , "XZ" ],
184- domain = "onnx_extended.ortops.optim.cuda " ,
182+ domain = "ai.onnx.contrib " ,
185183 )
186184 ],
187185 "nd" ,
@@ -197,7 +195,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
197195 ),
198196 opset_imports = [
199197 helper .make_opsetid ("" , 18 ),
200- helper .make_opsetid ("onnx_extended.ortops.optim.cuda " , 1 ),
198+ helper .make_opsetid ("ai.onnx.contrib " , 1 ),
201199 ],
202200 ir_version = 9 ,
203201 )
@@ -212,7 +210,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
212210 expected = ref .run (None , feeds1 )
213211
214212 opts = _ort .SessionOptions ()
215- opts .register_custom_ops_library (get_ort_ext_libs ()[ 0 ] )
213+ opts .register_custom_ops_library (_get_library_path () )
216214 sess = _ort .InferenceSession (model2 .SerializeToString (), opts , providers = ["CUDAExecutionProvider" ])
217215 got = sess .run (None , feeds1 )
218216 for i in range (2 ):
@@ -262,6 +260,67 @@ def test_add_shared_input_cuda_broadcast2(self):
262260 shapec = (3 , 2 , 3 ),
263261 )
264262
def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)):
    """
    Runs the custom operator ``Rotary`` (domain ``ai.onnx.contrib``) on
    CUDAExecutionProvider and compares its output with a numpy reference.

    The reference swaps the two halves of axis 3 of the input and negates
    one of them: with ``side == "left"`` the half written at the end is
    negated, with any other value the half written at the beginning is.

    :param itype: ONNX element type of ``X`` and ``Y``;
        ``TensorProto.FLOAT`` maps to float32, anything else to float16
    :param side: value given to the node attribute ``side``
    :param input_shape: shape of the generated input tensor
    """
    any_rank4 = [None, None, None, None]
    rotary_node = helper.make_node(
        "Rotary",
        ["X", "splits"],
        ["Y"],
        domain="ai.onnx.contrib",
        side=side,
    )
    graph = helper.make_graph(
        [rotary_node],
        "nd",
        [
            helper.make_tensor_value_info("X", itype, any_rank4),
            helper.make_tensor_value_info("splits", TensorProto.INT64, [2]),
        ],
        [helper.make_tensor_value_info("Y", itype, any_rank4)],
    )
    onx = helper.make_model(
        graph,
        opset_imports=[
            helper.make_opsetid("", 18),
            helper.make_opsetid("ai.onnx.contrib", 1),
        ],
        ir_version=9,
    )

    # Input values are 1, 2, 3, ... so every element is distinct and non-zero,
    # which makes a wrong swap or a missing sign flip visible.
    if itype == TensorProto.FLOAT:
        np_dtype = np.float32
    else:
        np_dtype = np.float16
    data = (np.arange(np.prod(input_shape)) + 1).reshape(input_shape).astype(np_dtype)
    mid = data.shape[-1] // 2
    splits = np.array([mid, mid], dtype=np.int64)

    # Signs applied to the values written into the first and second half.
    sign_first, sign_second = (1, -1) if side == "left" else (-1, 1)
    expected = data.copy()
    expected[:, :, :, :mid] = sign_first * data[:, :, :, mid:]
    expected[:, :, :, mid:] = sign_second * data[:, :, :, :mid]

    session_options = _ort.SessionOptions()
    session_options.register_custom_ops_library(_get_library_path())
    session = _ort.InferenceSession(
        onx.SerializeToString(),
        session_options,
        providers=["CUDAExecutionProvider"],
    )
    got = session.run(None, {"X": data, "splits": splits})[0]
    assert_almost_equal(expected, got)
308+
@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_rotary_cuda(self):
    """Checks Rotary on the default input shape for FLOAT and FLOAT16, both sides."""
    for itype in (TensorProto.FLOAT, TensorProto.FLOAT16):
        for side in ("left", "right"):
            self._rotary_cuda(itype, side)
315+
@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_bigger_rotary_cuda(self):
    """Checks Rotary on a (2, 2, 1024, 8) input for FLOAT and FLOAT16, both sides."""
    shape = (2, 2, 1024, 8)
    for itype in (TensorProto.FLOAT, TensorProto.FLOAT16):
        for side in ("left", "right"):
            self._rotary_cuda(itype, side, input_shape=shape)
323+
265324
266325if __name__ == "__main__" :
267326 unittest .main ()
0 commit comments