Skip to content

Commit 41cbd6d

Browse files
committed
Improve cuda driver initialization
1 parent 220b326 commit 41cbd6d

File tree

3 files changed

+54
-21
lines changed

3 files changed

+54
-21
lines changed

numba/cuda/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
from .stubs import (threadIdx, blockIdx, blockDim, gridDim, syncthreads,
55
shared, local, const, grid, atomic)
66
from .cudadrv.error import CudaSupportError
7+
8+
cuda_error = None
9+
710
try:
811
from .decorators import jit, autojit, declare_device
9-
except CudaSupportError:
12+
except CudaSupportError as e:
1013
is_available = False
14+
cuda_error = e
1115
else:
1216
is_available = True
1317
from .api import *

numba/cuda/cudadrv/driver.py

+43-20
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,21 @@ def find_driver():
5757
# Force fail
5858
_raise_driver_not_found()
5959

60+
# Determine DLL type
61+
if sys.platform == 'win32':
62+
dlloader = ctypes.WinDLL
63+
dldir = ['\\windows\\system32']
64+
dlname = 'nvcuda.dll'
65+
elif sys.platform == 'darwin':
66+
dlloader = ctypes.CDLL
67+
dldir = ['/usr/local/cuda/lib']
68+
dlname = 'libcuda.dylib'
69+
else:
70+
# Assume to be *nix like
71+
dlloader = ctypes.CDLL
72+
dldir = ['/usr/lib', '/usr/lib64']
73+
dlname = 'libcuda.so'
74+
6075
if envpath is not None:
6176
try:
6277
envpath = os.path.abspath(envpath)
@@ -69,21 +84,6 @@ def find_driver():
6984
".dll/.dylib or the driver" % envpath)
7085
candidates = [envpath]
7186
else:
72-
# Determine DLL type
73-
if sys.platform == 'win32':
74-
dlloader = ctypes.WinDLL
75-
dldir = ['\\windows\\system32']
76-
dlname = 'nvcuda.dll'
77-
elif sys.platform == 'darwin':
78-
dlloader = ctypes.CDLL
79-
dldir = ['/usr/local/cuda/lib']
80-
dlname = 'libcuda.dylib'
81-
else:
82-
# Assume to be *nix like
83-
dlloader = ctypes.CDLL
84-
dldir = ['/usr/lib', '/usr/lib64']
85-
dlname = 'libcuda.so'
86-
8787
# First search for the name in the default library path.
8888
# If that is not found, try the specific path.
8989
candidates = [dlname] + [os.path.join(x, dlname) for x in dldir]
@@ -143,6 +143,10 @@ def _build_reverse_error_map():
143143

144144
ERROR_MAP = _build_reverse_error_map()
145145

146+
MISSING_FUNCTION_ERRMSG = """driver missing function: %s.
147+
Requires CUDA 5.5 or above.
148+
"""
149+
146150

147151
class Driver(object):
148152
"""
@@ -157,6 +161,7 @@ def __new__(cls):
157161
else:
158162
obj = object.__new__(cls)
159163
obj.lib = find_driver()
164+
# Initialize driver
160165
obj.cuInit(0)
161166
cls._singleton = obj
162167
return obj
@@ -172,11 +177,8 @@ def __getattr__(self, fname):
172177
raise AttributeError(fname)
173178
restype = proto[0]
174179
argtypes = proto[1:]
175-
try:
176-
# Try newer API
177-
libfn = getattr(self.lib, fname + "_v2")
178-
except AttributeError:
179-
libfn = getattr(self.lib, fname)
180+
181+
libfn = self._find_api(fname)
180182
libfn.restype = restype
181183
libfn.argtypes = argtypes
182184

@@ -188,6 +190,27 @@ def safe_cuda_api_call(*args):
188190
setattr(self, fname, safe_cuda_api_call)
189191
return safe_cuda_api_call
190192

193+
def _find_api(self, fname):
194+
# Try version 2
195+
try:
196+
return getattr(self.lib, fname + "_v2")
197+
except AttributeError:
198+
pass
199+
200+
# Try regular
201+
try:
202+
return getattr(self.lib, fname)
203+
except AttributeError:
204+
pass
205+
206+
# Not found.
207+
# Delay missing function error to use
208+
def absent_function(*args, **kws):
209+
raise CudaDriverError(MISSING_FUNCTION_ERRMSG % fname)
210+
211+
setattr(self, fname, absent_function)
212+
return absent_function
213+
191214
def _check_error(self, fname, retcode):
192215
if retcode != enums.CUDA_SUCCESS:
193216
errname = ERROR_MAP.get(retcode, "UNKNOWN_CUDA_ERROR")

tools/dummy_libcuda/cuda.c

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// This is a dummy cuda driver for testing
2+
3+
int cuInit() {
4+
return 0;
5+
}
6+

0 commit comments

Comments
 (0)