@@ -817,3 +817,33 @@ def test_graph_poly():
817
817
assert (err == cuda .CUresult .CUDA_SUCCESS )
818
818
err , = cuda .cuCtxDestroy (ctx )
819
819
assert (err == cuda .CUresult .CUDA_SUCCESS )
820
+
821
+ @pytest .mark .skipif (driverVersionLessThan (12030 )
822
+ or not supportsCudaAPI ('cuGraphConditionalHandleCreate' ), reason = 'Conditional graph APIs required' )
823
+ def test_conditional ():
824
+ err , = cuda .cuInit (0 )
825
+ assert (err == cuda .CUresult .CUDA_SUCCESS )
826
+ err , device = cuda .cuDeviceGet (0 )
827
+ assert (err == cuda .CUresult .CUDA_SUCCESS )
828
+ err , ctx = cuda .cuCtxCreate (0 , device )
829
+ assert (err == cuda .CUresult .CUDA_SUCCESS )
830
+
831
+ err , graph = cuda .cuGraphCreate (0 )
832
+ assert (err == cuda .CUresult .CUDA_SUCCESS )
833
+ err , handle = cuda .cuGraphConditionalHandleCreate (graph , ctx , 0 , 0 )
834
+ assert (err == cuda .CUresult .CUDA_SUCCESS )
835
+
836
+ params = cuda .CUgraphNodeParams ()
837
+ params .type = cuda .CUgraphNodeType .CU_GRAPH_NODE_TYPE_CONDITIONAL
838
+ params .conditional .handle = handle
839
+ params .conditional .type = cuda .CUgraphConditionalNodeType .CU_GRAPH_COND_TYPE_IF
840
+ params .conditional .size = 1
841
+ params .conditional .ctx = ctx
842
+
843
+ assert (len (params .conditional .phGraph_out ) == 1 )
844
+ assert (int (params .conditional .phGraph_out [0 ]) == 0 )
845
+ err , node = cuda .cuGraphAddNode (graph , None , 0 , params )
846
+ assert (err == cuda .CUresult .CUDA_SUCCESS )
847
+
848
+ assert (len (params .conditional .phGraph_out ) == 1 )
849
+ assert (int (params .conditional .phGraph_out [0 ]) != 0 )
0 commit comments