@@ -21,6 +21,20 @@ def _run(self, X):
2121 return (1 - X ,)
2222
2323
class Transpose2DCastFP16(OpRun):
    """Reference (CPU) implementation of the fused custom op: 2-D transpose
    followed by a cast to float16, for use with onnx's ReferenceEvaluator."""

    op_domain = "ai.onnx.contrib"

    def _run(self, X):
        # X is a numpy ndarray here; ndarray has no ``.to`` (that is the
        # torch API) — ``astype`` is the numpy way to cast.
        return (X.T.astype(np.float16),)
29+
30+
class Transpose2DCastFP32(OpRun):
    """Reference (CPU) implementation of the fused custom op: 2-D transpose
    followed by a cast to float32, for use with onnx's ReferenceEvaluator."""

    op_domain = "ai.onnx.contrib"

    def _run(self, X):
        # X is a numpy ndarray here; ndarray has no ``.to`` (that is the
        # torch API) — ``astype`` is the numpy way to cast.
        return (X.T.astype(np.float32),)
36+
37+
2438class TestCudaOps (unittest .TestCase ):
2539 @staticmethod
2640 def _create_negpos_test_model (domain = "ai.onnx.contrib" ):
@@ -321,6 +335,62 @@ def test_bigger_rotary_cuda(self):
321335 self ._rotary_cuda (TensorProto .FLOAT16 , "left" , input_shape = sh )
322336 self ._rotary_cuda (TensorProto .FLOAT16 , "right" , input_shape = sh )
323337
338+ def _transpose_cast_cuda (self , itype ):
339+ dtype = np .float32 if itype == TensorProto .FLOAT else np .float16
340+ itype2 = TensorProto .FLOAT if itype == TensorProto .FLOAT16 else TensorProto .FLOAT16
341+ model1 = helper .make_model (
342+ helper .make_graph (
343+ [
344+ helper .make_node ("Transpose" , ["X" ], ["t" ], perm = [1 , 0 ]),
345+ helper .make_node ("Cast" , ["t" ], ["Y" ], to = itype2 ),
346+ ],
347+ "nd" ,
348+ [helper .make_tensor_value_info ("X" , itype , [None , None ])],
349+ [helper .make_tensor_value_info ("Y" , itype2 , [None , None ])],
350+ ),
351+ opset_imports = [helper .make_opsetid ("" , 18 )],
352+ ir_version = 9 ,
353+ )
354+
355+ model2 = helper .make_model (
356+ helper .make_graph (
357+ [
358+ helper .make_node (
359+ ("Transpose2DCastFP16" if itype2 == TensorProto .FLOAT16 else "Transpose2DCastFP32" ),
360+ ["X" ],
361+ ["Y" ],
362+ domain = "ai.onnx.contrib" ,
363+ )
364+ ],
365+ "nd" ,
366+ [helper .make_tensor_value_info ("X" , itype , [None , None ])],
367+ [helper .make_tensor_value_info ("Y" , itype2 , [None , None ])],
368+ ),
369+ opset_imports = [
370+ helper .make_opsetid ("" , 18 ),
371+ helper .make_opsetid ("ai.onnx.contrib" , 1 ),
372+ ],
373+ ir_version = 9 ,
374+ )
375+
376+ dtype = np .float32 if itype == TensorProto .FLOAT else np .float16
377+ x = (np .arange (32 * 32 * 3 ) + 1 ).reshape ((32 , 32 * 3 )).astype (dtype )
378+
379+ feeds1 = dict (X = x )
380+ ref = ReferenceEvaluator (model1 , new_ops = [Transpose2DCastFP16 , Transpose2DCastFP32 ])
381+ expected = ref .run (None , feeds1 )[0 ]
382+
383+ opts = _ort .SessionOptions ()
384+ opts .register_custom_ops_library (_get_library_path ())
385+ sess = _ort .InferenceSession (model2 .SerializeToString (), opts , providers = ["CUDAExecutionProvider" ])
386+ got = sess .run (None , feeds1 )[0 ]
387+ assert_almost_equal (expected , got , decimal = 5 )
388+
389+ @unittest .skipIf (not has_cuda (), reason = "cuda not available" )
390+ def test_transpose_cast_cuda (self ):
391+ self ._transpose_cast_cuda (TensorProto .FLOAT )
392+ self ._transpose_cast_cuda (TensorProto .FLOAT16 )
393+
324394
if __name__ == "__main__":
    # verbosity=2 prints each test name as it runs, useful on CI logs.
    unittest.main(verbosity=2)