@@ -57,6 +57,21 @@ def find_driver():
57
57
# Force fail
58
58
_raise_driver_not_found ()
59
59
60
+ # Determine DLL type
61
+ if sys .platform == 'win32' :
62
+ dlloader = ctypes .WinDLL
63
+ dldir = ['\\ windows\\ system32' ]
64
+ dlname = 'nvcuda.dll'
65
+ elif sys .platform == 'darwin' :
66
+ dlloader = ctypes .CDLL
67
+ dldir = ['/usr/local/cuda/lib' ]
68
+ dlname = 'libcuda.dylib'
69
+ else :
70
+ # Assume to be *nix like
71
+ dlloader = ctypes .CDLL
72
+ dldir = ['/usr/lib' , '/usr/lib64' ]
73
+ dlname = 'libcuda.so'
74
+
60
75
if envpath is not None :
61
76
try :
62
77
envpath = os .path .abspath (envpath )
@@ -69,21 +84,6 @@ def find_driver():
69
84
".dll/.dylib or the driver" % envpath )
70
85
candidates = [envpath ]
71
86
else :
72
- # Determine DLL type
73
- if sys .platform == 'win32' :
74
- dlloader = ctypes .WinDLL
75
- dldir = ['\\ windows\\ system32' ]
76
- dlname = 'nvcuda.dll'
77
- elif sys .platform == 'darwin' :
78
- dlloader = ctypes .CDLL
79
- dldir = ['/usr/local/cuda/lib' ]
80
- dlname = 'libcuda.dylib'
81
- else :
82
- # Assume to be *nix like
83
- dlloader = ctypes .CDLL
84
- dldir = ['/usr/lib' , '/usr/lib64' ]
85
- dlname = 'libcuda.so'
86
-
87
87
# First search for the name in the default library path.
88
88
# If that is not found, try the specific path.
89
89
candidates = [dlname ] + [os .path .join (x , dlname ) for x in dldir ]
@@ -143,6 +143,10 @@ def _build_reverse_error_map():
143
143
144
144
ERROR_MAP = _build_reverse_error_map ()
145
145
146
+ MISSING_FUNCTION_ERRMSG = """driver missing function: %s.
147
+ Requires CUDA 5.5 or above.
148
+ """
149
+
146
150
147
151
class Driver (object ):
148
152
"""
@@ -157,6 +161,7 @@ def __new__(cls):
157
161
else :
158
162
obj = object .__new__ (cls )
159
163
obj .lib = find_driver ()
164
+ # Initialize driver
160
165
obj .cuInit (0 )
161
166
cls ._singleton = obj
162
167
return obj
@@ -172,11 +177,8 @@ def __getattr__(self, fname):
172
177
raise AttributeError (fname )
173
178
restype = proto [0 ]
174
179
argtypes = proto [1 :]
175
- try :
176
- # Try newer API
177
- libfn = getattr (self .lib , fname + "_v2" )
178
- except AttributeError :
179
- libfn = getattr (self .lib , fname )
180
+
181
+ libfn = self ._find_api (fname )
180
182
libfn .restype = restype
181
183
libfn .argtypes = argtypes
182
184
@@ -188,6 +190,27 @@ def safe_cuda_api_call(*args):
188
190
setattr (self , fname , safe_cuda_api_call )
189
191
return safe_cuda_api_call
190
192
193
+ def _find_api (self , fname ):
194
+ # Try version 2
195
+ try :
196
+ return getattr (self .lib , fname + "_v2" )
197
+ except AttributeError :
198
+ pass
199
+
200
+ # Try regular
201
+ try :
202
+ return getattr (self .lib , fname )
203
+ except AttributeError :
204
+ pass
205
+
206
+ # Not found.
207
+ # Delay missing function error to use
208
+ def absent_function (* args , ** kws ):
209
+ raise CudaDriverError (MISSING_FUNCTION_ERRMSG % fname )
210
+
211
+ setattr (self , fname , absent_function )
212
+ return absent_function
213
+
191
214
def _check_error (self , fname , retcode ):
192
215
if retcode != enums .CUDA_SUCCESS :
193
216
errname = ERROR_MAP .get (retcode , "UNKNOWN_CUDA_ERROR" )
0 commit comments