Skip to content

Commit 82df864

Browse files
authored
Merge pull request #497 from rwgk/fix_segfault_char_ptr_to_bytes
cuda.bindings: Fix segfault when converting `char*` `NULL` to `bytes`
2 parents 6897c26 + a78a8aa commit 82df864

File tree

3 files changed

+49
-5
lines changed

3 files changed

+49
-5
lines changed

cuda_bindings/cuda/bindings/driver.pyx.in

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22246,7 +22246,7 @@ def cuGetErrorString(error not None : CUresult):
2224622246
cdef cydriver.CUresult cyerror = error.value
2224722247
cdef const char* pStr = NULL
2224822248
err = cydriver.cuGetErrorString(cyerror, &pStr)
22249-
return (CUresult(err), <bytes>pStr)
22249+
return (CUresult(err), <bytes>pStr if pStr != NULL else None)
2225022250
{{endif}}
2225122251

2225222252
{{if 'cuGetErrorName' in found_functions}}
@@ -22279,7 +22279,7 @@ def cuGetErrorName(error not None : CUresult):
2227922279
cdef cydriver.CUresult cyerror = error.value
2228022280
cdef const char* pStr = NULL
2228122281
err = cydriver.cuGetErrorName(cyerror, &pStr)
22282-
return (CUresult(err), <bytes>pStr)
22282+
return (CUresult(err), <bytes>pStr if pStr != NULL else None)
2228322283
{{endif}}
2228422284

2228522285
{{if 'cuInit' in found_functions}}
@@ -27132,7 +27132,7 @@ def cuKernelGetName(hfunc):
2713227132
cyhfunc = <cydriver.CUkernel><void_ptr>phfunc
2713327133
cdef const char* name = NULL
2713427134
err = cydriver.cuKernelGetName(&name, cyhfunc)
27135-
return (CUresult(err), <bytes>name)
27135+
return (CUresult(err), <bytes>name if name != NULL else None)
2713627136
{{endif}}
2713727137

2713827138
{{if 'cuKernelGetParamInfo' in found_functions}}
@@ -38744,7 +38744,7 @@ def cuFuncGetName(hfunc):
3874438744
cyhfunc = <cydriver.CUfunction><void_ptr>phfunc
3874538745
cdef const char* name = NULL
3874638746
err = cydriver.cuFuncGetName(&name, cyhfunc)
38747-
return (CUresult(err), <bytes>name)
38747+
return (CUresult(err), <bytes>name if name != NULL else None)
3874838748
{{endif}}
3874938749

3875038750
{{if 'cuFuncGetParamInfo' in found_functions}}

cuda_bindings/cuda/bindings/nvrtc.pyx.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ def nvrtcGetLoweredName(prog, char* name_expression):
870870
cyprog = <cynvrtc.nvrtcProgram><void_ptr>pprog
871871
cdef const char* lowered_name = NULL
872872
err = cynvrtc.nvrtcGetLoweredName(cyprog, name_expression, &lowered_name)
873-
return (nvrtcResult(err), <bytes>lowered_name)
873+
return (nvrtcResult(err), <bytes>lowered_name if lowered_name != NULL else None)
874874
{{endif}}
875875

876876
{{if 'nvrtcGetPCHHeapSize' in found_functions}}

cuda_bindings/tests/test_cuda.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,3 +948,47 @@ def test_conditional():
948948
def test_CUmemDecompressParams_st():
949949
desc = cuda.CUmemDecompressParams_st()
950950
assert int(desc.dstActBytes) == 0
951+
952+
953+
def test_all_CUresult_codes():
954+
max_code = int(max(cuda.CUresult))
955+
# Smoke test. CUDA_ERROR_UNKNOWN = 999, but intentionally using literal value.
956+
assert max_code >= 999
957+
num_good = 0
958+
for code in range(max_code + 2): # One past max_code
959+
try:
960+
error = cuda.CUresult(code)
961+
except ValueError:
962+
pass # cython-generated enum does not exist for this code
963+
else:
964+
err_name, name = cuda.cuGetErrorName(error)
965+
if err_name == cuda.CUresult.CUDA_SUCCESS:
966+
assert name
967+
err_desc, desc = cuda.cuGetErrorString(error)
968+
assert err_desc == cuda.CUresult.CUDA_SUCCESS
969+
assert desc
970+
num_good += 1
971+
else:
972+
# cython-generated enum exists but is not known to an older driver
973+
# (example: cuda-bindings built with CTK 12.8, driver from CTK 12.0)
974+
assert name is None
975+
assert err_name == cuda.CUresult.CUDA_ERROR_INVALID_VALUE
976+
err_desc, desc = cuda.cuGetErrorString(error)
977+
assert err_desc == cuda.CUresult.CUDA_ERROR_INVALID_VALUE
978+
assert desc is None
979+
# Smoke test: Do we have at least some "good" codes?
980+
# The number will increase over time as new enums are added and support for
981+
# old CTKs is dropped, but it is not critical that this number is updated.
982+
assert num_good >= 76 # CTK 11.0.3_450.51.06
983+
984+
985+
def test_cuKernelGetName_failure():
986+
err, name = cuda.cuKernelGetName(0)
987+
assert err == cuda.CUresult.CUDA_ERROR_INVALID_VALUE
988+
assert name is None
989+
990+
991+
def test_cuFuncGetName_failure():
992+
err, name = cuda.cuFuncGetName(0)
993+
assert err == cuda.CUresult.CUDA_ERROR_INVALID_VALUE
994+
assert name is None

0 commit comments

Comments
 (0)