Skip to content

Commit 884fff3

Browse files
authored
Add missing nvvmGetErrorString() bindings. (#690)
* Update generated code to include `nvvmGetErrorString` * Add test_get_error_string()
1 parent 9ef1562 commit 884fff3

File tree

8 files changed

+75
-0
lines changed

8 files changed

+75
-0
lines changed

cuda_bindings/cuda/bindings/_internal/nvvm.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from ..cynvvm cimport *
1111
# Wrapper functions
1212
###############################################################################
1313

14+
cdef const char* _nvvmGetErrorString(nvvmResult result) except?NULL nogil
1415
cdef nvvmResult _nvvmVersion(int* major, int* minor) except?_NVVMRESULT_INTERNAL_LOADING_ERROR nogil
1516
cdef nvvmResult _nvvmIRVersion(int* majorIR, int* minorIR, int* majorDbg, int* minorDbg) except?_NVVMRESULT_INTERNAL_LOADING_ERROR nogil
1617
cdef nvvmResult _nvvmCreateProgram(nvvmProgram* prog) except?_NVVMRESULT_INTERNAL_LOADING_ERROR nogil

cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ cdef extern from "<dlfcn.h>" nogil:
3636
cdef bint __py_nvvm_init = False
3737
cdef void* __cuDriverGetVersion = NULL
3838

39+
cdef void* __nvvmGetErrorString = NULL
3940
cdef void* __nvvmVersion = NULL
4041
cdef void* __nvvmIRVersion = NULL
4142
cdef void* __nvvmCreateProgram = NULL
@@ -82,6 +83,13 @@ cdef int _check_or_init_nvvm() except -1 nogil:
8283
handle = NULL
8384

8485
# Load function
86+
global __nvvmGetErrorString
87+
__nvvmGetErrorString = dlsym(RTLD_DEFAULT, 'nvvmGetErrorString')
88+
if __nvvmGetErrorString == NULL:
89+
if handle == NULL:
90+
handle = load_library(driver_ver)
91+
__nvvmGetErrorString = dlsym(handle, 'nvvmGetErrorString')
92+
8593
global __nvvmVersion
8694
__nvvmVersion = dlsym(RTLD_DEFAULT, 'nvvmVersion')
8795
if __nvvmVersion == NULL:
@@ -181,6 +189,9 @@ cpdef dict _inspect_function_pointers():
181189
_check_or_init_nvvm()
182190
cdef dict data = {}
183191

192+
global __nvvmGetErrorString
193+
data["__nvvmGetErrorString"] = <intptr_t>__nvvmGetErrorString
194+
184195
global __nvvmVersion
185196
data["__nvvmVersion"] = <intptr_t>__nvvmVersion
186197

@@ -232,6 +243,16 @@ cpdef _inspect_function_pointer(str name):
232243
# Wrapper functions
233244
###############################################################################
234245

246+
cdef const char* _nvvmGetErrorString(nvvmResult result) except?NULL nogil:
247+
global __nvvmGetErrorString
248+
_check_or_init_nvvm()
249+
if __nvvmGetErrorString == NULL:
250+
with gil:
251+
raise FunctionNotFoundError("function nvvmGetErrorString is not found")
252+
return (<const char* (*)(nvvmResult) noexcept nogil>__nvvmGetErrorString)(
253+
result)
254+
255+
235256
cdef nvvmResult _nvvmVersion(int* major, int* minor) except?_NVVMRESULT_INTERNAL_LOADING_ERROR nogil:
236257
global __nvvmVersion
237258
_check_or_init_nvvm()

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
2323
cdef bint __py_nvvm_init = False
2424
cdef void* __cuDriverGetVersion = NULL
2525

26+
cdef void* __nvvmGetErrorString = NULL
2627
cdef void* __nvvmVersion = NULL
2728
cdef void* __nvvmIRVersion = NULL
2829
cdef void* __nvvmCreateProgram = NULL
@@ -62,6 +63,12 @@ cdef int _check_or_init_nvvm() except -1 nogil:
6263
handle = path_finder._load_nvidia_dynamic_library("nvvm").handle
6364

6465
# Load function
66+
global __nvvmGetErrorString
67+
try:
68+
__nvvmGetErrorString = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmGetErrorString')
69+
except:
70+
pass
71+
6572
global __nvvmVersion
6673
try:
6774
__nvvmVersion = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmVersion')
@@ -149,6 +156,9 @@ cpdef dict _inspect_function_pointers():
149156
_check_or_init_nvvm()
150157
cdef dict data = {}
151158

159+
global __nvvmGetErrorString
160+
data["__nvvmGetErrorString"] = <intptr_t>__nvvmGetErrorString
161+
152162
global __nvvmVersion
153163
data["__nvvmVersion"] = <intptr_t>__nvvmVersion
154164

@@ -200,6 +210,16 @@ cpdef _inspect_function_pointer(str name):
200210
# Wrapper functions
201211
###############################################################################
202212

213+
cdef const char* _nvvmGetErrorString(nvvmResult result) except?NULL nogil:
214+
global __nvvmGetErrorString
215+
_check_or_init_nvvm()
216+
if __nvvmGetErrorString == NULL:
217+
with gil:
218+
raise FunctionNotFoundError("function nvvmGetErrorString is not found")
219+
return (<const char* (*)(nvvmResult) noexcept nogil>__nvvmGetErrorString)(
220+
result)
221+
222+
203223
cdef nvvmResult _nvvmVersion(int* major, int* minor) except?_NVVMRESULT_INTERNAL_LOADING_ERROR nogil:
204224
global __nvvmVersion
205225
_check_or_init_nvvm()

cuda_bindings/cuda/bindings/cynvvm.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ ctypedef void* nvvmProgram 'nvvmProgram'
3333
# Functions
3434
###############################################################################
3535

36+
cdef const char* nvvmGetErrorString(nvvmResult result) except?NULL nogil
3637
cdef nvvmResult nvvmVersion(int* major, int* minor) except?_NVVMRESULT_INTERNAL_LOADING_ERROR nogil
3738
cdef nvvmResult nvvmIRVersion(int* majorIR, int* minorIR, int* majorDbg, int* minorDbg) except?_NVVMRESULT_INTERNAL_LOADING_ERROR nogil
3839
cdef nvvmResult nvvmCreateProgram(nvvmProgram* prog) except?_NVVMRESULT_INTERNAL_LOADING_ERROR nogil

cuda_bindings/cuda/bindings/cynvvm.pyx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ from ._internal cimport nvvm as _nvvm
1111
# Wrapper functions
1212
###############################################################################
1313

14+
cdef const char* nvvmGetErrorString(nvvmResult result) except?NULL nogil:
15+
return _nvvm._nvvmGetErrorString(result)
16+
17+
1418
cdef nvvmResult nvvmVersion(int* major, int* minor) except?_NVVMRESULT_INTERNAL_LOADING_ERROR nogil:
1519
return _nvvm._nvvmVersion(major, minor)
1620

cuda_bindings/cuda/bindings/nvvm.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ ctypedef nvvmResult _Result
2727
# Functions
2828
###############################################################################
2929

30+
cpdef str get_error_string(int result)
3031
cpdef tuple version()
3132
cpdef tuple ir_version()
3233
cpdef intptr_t create_program() except? 0

cuda_bindings/cuda/bindings/nvvm.pyx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ cpdef destroy_program(intptr_t prog):
7373
check_status(status)
7474

7575

76+
cpdef str get_error_string(int result):
77+
"""Get the message string for the given ``nvvmResult`` code.
78+
79+
Args:
80+
result (Result): NVVM API result code.
81+
82+
.. seealso:: `nvvmGetErrorString`
83+
"""
84+
cdef bytes _output_
85+
_output_ = nvvmGetErrorString(<_Result>result)
86+
return _output_.decode()
87+
88+
7689
cpdef tuple version():
7790
"""Get the NVVM version.
7891

cuda_bindings/tests/test_nvvm.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,20 @@ def get_program_log(prog):
194194
return buffer.decode(errors="backslashreplace")
195195

196196

197+
def test_get_error_string():
198+
num_success = 0
199+
num_errors = 0
200+
for enum_obj in nvvm.Result:
201+
es = nvvm.get_error_string(enum_obj)
202+
if enum_obj is nvvm.Result.SUCCESS:
203+
num_success += 1
204+
else:
205+
assert es.startswith("NVVM_ERROR")
206+
num_errors += 1
207+
assert num_success == 1
208+
assert num_errors > 1 # smoke check is sufficient
209+
210+
197211
def test_nvvm_version():
198212
ver = nvvm.version()
199213
assert len(ver) == 2

0 commit comments

Comments
 (0)