
Commit dfd31fa

Add missing conditional node support
Fixes issue #55
1 parent 174a6b8 commit dfd31fa

5 files changed: +71, -40 lines

cuda/_lib/ccudart/utils.pyx.in

Lines changed: 11 additions & 0 deletions
@@ -3410,6 +3410,15 @@ cdef cudaError_t toDriverGraphNodeParams(const cudaGraphNodeParams *rtParams, cc
     elif rtParams[0].type == cudaGraphNodeType.cudaGraphNodeTypeMemFree:
         driverParams[0].type = ccuda.CUgraphNodeType_enum.CU_GRAPH_NODE_TYPE_MEM_FREE
         driverParams[0].free.dptr = <ccuda.CUdeviceptr>rtParams[0].free.dptr
+    elif rtParams[0].type == cudaGraphNodeType.cudaGraphNodeTypeConditional:
+        driverParams[0].type = ccuda.CUgraphNodeType_enum.CU_GRAPH_NODE_TYPE_CONDITIONAL
+        # RT params mirror the driver params except the RT struct lacks the ctx at the end.
+        memcpy(&driverParams[0].conditional, &rtParams[0].conditional, sizeof(rtParams[0].conditional))
+        err = <cudaError_t>ccuda._cuCtxGetCurrent(&context)
+        if err != cudaSuccess:
+            _setLastError(err)
+            return err
+        driverParams[0].conditional.ctx = context
     else:
         return cudaErrorInvalidValue
     return cudaSuccess
@@ -3418,6 +3427,8 @@ cdef cudaError_t toDriverGraphNodeParams(const cudaGraphNodeParams *rtParams, cc
 cdef void toCudartGraphNodeOutParams(const ccuda.CUgraphNodeParams *driverParams, cudaGraphNodeParams *rtParams) nogil:
     if driverParams[0].type == ccuda.CUgraphNodeType_enum.CU_GRAPH_NODE_TYPE_MEM_ALLOC:
         rtParams[0].alloc.dptr = <void *>driverParams[0].alloc.dptr
+    elif driverParams[0].type == ccuda.CUgraphNodeType_enum.CU_GRAPH_NODE_TYPE_CONDITIONAL:
+        rtParams[0].conditional.phGraph_out = <cudaGraph_t *>driverParams[0].conditional.phGraph_out


 cdef cudaError_t toDriverKernelNodeParams(const cudaKernelNodeParams nodeParams[0], ccuda.CUDA_KERNEL_NODE_PARAMS *driverNodeParams) except ?cudaErrorCallRequiresNewerDriver nogil:
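The new cudaGraphNodeTypeConditional branch above is what the runtime-API path hits when a conditional node is added through cudaGraphAddNode: the conditional params are copied wholesale, and the ctx field (absent from the runtime struct) is filled in from the current context. A minimal sketch of that path, mirroring the test_cudart_conditional test added later in this commit and assuming a CUDA 12.3+ driver (error checks and cleanup omitted):

from cuda import cudart

err, graph = cudart.cudaGraphCreate(0)
err, handle = cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)

params = cudart.cudaGraphNodeParams()
params.type = cudart.cudaGraphNodeType.cudaGraphNodeTypeConditional
params.conditional.handle = handle
params.conditional.type = cudart.cudaGraphConditionalNodeType.cudaGraphCondTypeIf
params.conditional.size = 1
# No ctx field on the runtime struct: toDriverGraphNodeParams fills it in
# from cuCtxGetCurrent before forwarding the params to the driver.

err, node = cudart.cudaGraphAddNode(graph, None, 0, params)
# The driver writes the body graph(s) back into the struct; phGraph_out is now non-zero.
body_graph = params.conditional.phGraph_out[0]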

cuda/cuda.pyx.in

Lines changed: 2 additions & 20 deletions
@@ -8441,8 +8441,7 @@ cdef class CUDA_CONDITIONAL_NODE_PARAMS:
         self._handle = CUgraphConditionalHandle(_ptr=<void_ptr>&self._ptr[0].handle)
         self._ctx = CUcontext(_ptr=<void_ptr>&self._ptr[0].ctx)
     def __dealloc__(self):
-        if self._phGraph_out is not NULL:
-            free(self._phGraph_out)
+        pass
     def getPtr(self):
         return <void_ptr>self._ptr
     def __repr__(self):
@@ -8501,25 +8500,8 @@ cdef class CUDA_CONDITIONAL_NODE_PARAMS:
         self._ptr[0].size = size
     @property
     def phGraph_out(self):
-        arrs = [<void_ptr>self._ptr[0].phGraph_out + x*sizeof(ccuda.CUgraph) for x in range(self._phGraph_out_length)]
+        arrs = [<void_ptr>self._ptr[0].phGraph_out + x*sizeof(ccuda.CUgraph) for x in range(self.size)]
         return [CUgraph(_ptr=arr) for arr in arrs]
-    @phGraph_out.setter
-    def phGraph_out(self, val):
-        if len(val) == 0:
-            free(self._phGraph_out)
-            self._phGraph_out_length = 0
-            self._ptr[0].phGraph_out = NULL
-        else:
-            if self._phGraph_out_length != <size_t>len(val):
-                free(self._phGraph_out)
-                self._phGraph_out = <ccuda.CUgraph*> calloc(len(val), sizeof(ccuda.CUgraph))
-                if self._phGraph_out is NULL:
-                    raise MemoryError('Failed to allocate length x size memory: ' + str(len(val)) + 'x' + str(sizeof(ccuda.CUgraph)))
-                self._phGraph_out_length = <size_t>len(val)
-                self._ptr[0].phGraph_out = self._phGraph_out
-            for idx in range(len(val)):
-                self._phGraph_out[idx] = (<CUgraph>val[idx])._ptr[0]
-
     @property
     def ctx(self):
         return self._ctx
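The net effect of dropping the setter is that phGraph_out becomes a read-only, driver-populated field: the wrapper no longer allocates or frees the array, and the property's length now tracks the size field. A minimal sketch of the resulting usage on the driver-API side, mirroring the test_conditional test added later in this commit and assuming a CUDA 12.3+ driver (error checks and cleanup omitted):

from cuda import cuda

err, = cuda.cuInit(0)
err, device = cuda.cuDeviceGet(0)
err, ctx = cuda.cuCtxCreate(0, device)

err, graph = cuda.cuGraphCreate(0)
err, handle = cuda.cuGraphConditionalHandleCreate(graph, ctx, 0, 0)

params = cuda.CUgraphNodeParams()
params.type = cuda.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
params.conditional.handle = handle
params.conditional.type = cuda.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
params.conditional.size = 1
params.conditional.ctx = ctx

# Before the node exists, phGraph_out has `size` entries, all zero.
err, node = cuda.cuGraphAddNode(graph, None, 0, params)
# Afterwards the driver has filled in the body graph; no user-side allocation is involved.
body_graph = params.conditional.phGraph_out[0]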

cuda/cudart.pyx.in

Lines changed: 2 additions & 20 deletions
@@ -10878,8 +10878,7 @@ cdef class cudaConditionalNodeParams:
     def __init__(self, void_ptr _ptr = 0):
         self._handle = cudaGraphConditionalHandle(_ptr=<void_ptr>&self._ptr[0].handle)
     def __dealloc__(self):
-        if self._phGraph_out is not NULL:
-            free(self._phGraph_out)
+        pass
     def getPtr(self):
         return <void_ptr>self._ptr
     def __repr__(self):
@@ -10934,25 +10933,8 @@ cdef class cudaConditionalNodeParams:
         self._ptr[0].size = size
     @property
     def phGraph_out(self):
-        arrs = [<void_ptr>self._ptr[0].phGraph_out + x*sizeof(ccudart.cudaGraph_t) for x in range(self._phGraph_out_length)]
+        arrs = [<void_ptr>self._ptr[0].phGraph_out + x*sizeof(ccudart.cudaGraph_t) for x in range(self.size)]
         return [cudaGraph_t(_ptr=arr) for arr in arrs]
-    @phGraph_out.setter
-    def phGraph_out(self, val):
-        if len(val) == 0:
-            free(self._phGraph_out)
-            self._phGraph_out_length = 0
-            self._ptr[0].phGraph_out = NULL
-        else:
-            if self._phGraph_out_length != <size_t>len(val):
-                free(self._phGraph_out)
-                self._phGraph_out = <ccudart.cudaGraph_t*> calloc(len(val), sizeof(ccudart.cudaGraph_t))
-                if self._phGraph_out is NULL:
-                    raise MemoryError('Failed to allocate length x size memory: ' + str(len(val)) + 'x' + str(sizeof(ccudart.cudaGraph_t)))
-                self._phGraph_out_length = <size_t>len(val)
-                self._ptr[0].phGraph_out = self._phGraph_out
-            for idx in range(len(val)):
-                self._phGraph_out[idx] = (<cudaGraph_t>val[idx])._ptr[0]
-
 {{endif}}
 {{if 'struct cudaChildGraphNodeParams' in found_types}}

cuda/tests/test_cuda.py

Lines changed: 30 additions & 0 deletions
@@ -817,3 +817,33 @@ def test_graph_poly():
     assert(err == cuda.CUresult.CUDA_SUCCESS)
     err, = cuda.cuCtxDestroy(ctx)
     assert(err == cuda.CUresult.CUDA_SUCCESS)
+
+@pytest.mark.skipif(driverVersionLessThan(12030)
+    or not supportsCudaAPI('cuGraphConditionalHandleCreate'), reason='Conditional graph APIs required')
+def test_conditional():
+    err, = cuda.cuInit(0)
+    assert(err == cuda.CUresult.CUDA_SUCCESS)
+    err, device = cuda.cuDeviceGet(0)
+    assert(err == cuda.CUresult.CUDA_SUCCESS)
+    err, ctx = cuda.cuCtxCreate(0, device)
+    assert(err == cuda.CUresult.CUDA_SUCCESS)
+
+    err, graph = cuda.cuGraphCreate(0)
+    assert(err == cuda.CUresult.CUDA_SUCCESS)
+    err, handle = cuda.cuGraphConditionalHandleCreate(graph, ctx, 0, 0)
+    assert(err == cuda.CUresult.CUDA_SUCCESS)
+
+    params = cuda.CUgraphNodeParams()
+    params.type = cuda.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
+    params.conditional.handle = handle
+    params.conditional.type = cuda.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
+    params.conditional.size = 1
+    params.conditional.ctx = ctx
+
+    assert(len(params.conditional.phGraph_out) == 1)
+    assert(int(params.conditional.phGraph_out[0]) == 0)
+    err, node = cuda.cuGraphAddNode(graph, None, 0, params)
+    assert(err == cuda.CUresult.CUDA_SUCCESS)
+
+    assert(len(params.conditional.phGraph_out) == 1)
+    assert(int(params.conditional.phGraph_out[0]) != 0)

cuda/tests/test_cudart.py

Lines changed: 26 additions & 0 deletions
@@ -32,6 +32,9 @@ def supportsSparseTexturesDeviceFilter():
     err, isSupported = cudart.cudaDeviceGetAttribute(cudart.cudaDeviceAttr.cudaDevAttrSparseCudaArraySupported, 0)
     return isSuccess(err) and isSupported

+def supportsCudaAPI(name):
+    return name in dir(cuda) or name in dir(cudart)
+
 def test_cudart_memcpy():
     # Allocate dev memory
     size = 1024 * np.uint8().itemsize
@@ -1275,3 +1278,26 @@ def task_callback_stream(stream, status, userData):
 def test_cudart_func_callback():
     cudart_func_stream_callback(use_host_api=False)
     cudart_func_stream_callback(use_host_api=True)
+
+
+@pytest.mark.skipif(driverVersionLessThan(12030)
+    or not supportsCudaAPI('cudaGraphConditionalHandleCreate'), reason='Conditional graph APIs required')
+def test_cudart_conditional():
+    err, graph = cudart.cudaGraphCreate(0)
+    assertSuccess(err)
+    err, handle = cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)
+    assertSuccess(err)
+
+    params = cudart.cudaGraphNodeParams()
+    params.type = cudart.cudaGraphNodeType.cudaGraphNodeTypeConditional
+    params.conditional.handle = handle
+    params.conditional.type = cudart.cudaGraphConditionalNodeType.cudaGraphCondTypeIf
+    params.conditional.size = 1
+
+    assert(len(params.conditional.phGraph_out) == 1)
+    assert(int(params.conditional.phGraph_out[0]) == 0)
+    err, node = cudart.cudaGraphAddNode(graph, None, 0, params)
+    assertSuccess(err)
+
+    assert(len(params.conditional.phGraph_out) == 1)
+    assert(int(params.conditional.phGraph_out[0]) != 0)
