13
13
from cuda .core .experimental ._utils .cuda_utils import (
14
14
ComputeCapability ,
15
15
CUDAError ,
16
+ _check_driver_error ,
16
17
driver ,
17
18
handle_return ,
18
19
precondition ,
@@ -930,6 +931,10 @@ def multicast_supported(self) -> bool:
930
931
return bool (self ._get_cached_attribute (driver .CUdevice_attribute .CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED ))
931
932
932
933
934
+ _SUCCESS = driver .CUresult .CUDA_SUCCESS
935
+ _INVALID_CTX = driver .CUresult .CUDA_ERROR_INVALID_CONTEXT
936
+
937
+
933
938
class Device :
934
939
"""Represent a GPU and act as an entry point for cuda.core features.
935
940
@@ -959,7 +964,7 @@ class Device:
959
964
960
965
__slots__ = ("_id" , "_mr" , "_has_inited" , "_properties" )
961
966
962
- def __new__ (cls , device_id = None ):
967
+ def __new__ (cls , device_id : Optional [ int ] = None ):
963
968
global _is_cuInit
964
969
if _is_cuInit is False :
965
970
with _lock :
@@ -968,26 +973,34 @@ def __new__(cls, device_id=None):
968
973
969
974
# important: creating a Device instance does not initialize the GPU!
970
975
if device_id is None :
971
- device_id = handle_return (runtime .cudaGetDevice ())
972
- assert_type (device_id , int )
973
- else :
974
- total = handle_return (runtime .cudaGetDeviceCount ())
975
- assert_type (device_id , int )
976
- if not (0 <= device_id < total ):
977
- raise ValueError (f"device_id must be within [0, { total } ), got { device_id } " )
976
+ err , dev = driver .cuCtxGetDevice ()
977
+ if err == _SUCCESS :
978
+ device_id = int (dev )
979
+ elif err == _INVALID_CTX :
980
+ ctx = handle_return (driver .cuCtxGetCurrent ())
981
+ assert int (ctx ) == 0
982
+ device_id = 0 # cudart behavior
983
+ else :
984
+ _check_driver_error (err )
985
+ elif device_id < 0 :
986
+ raise ValueError (f"device_id must be >= 0, got { device_id } " )
978
987
979
988
# ensure Device is singleton
980
- if not hasattr (_tls , "devices" ):
981
- total = handle_return (runtime .cudaGetDeviceCount ())
982
- _tls .devices = []
989
+ try :
990
+ devices = _tls .devices
991
+ except AttributeError :
992
+ total = handle_return (driver .cuDeviceGetCount ())
993
+ devices = _tls .devices = []
983
994
for dev_id in range (total ):
984
995
dev = super ().__new__ (cls )
985
996
dev ._id = dev_id
986
997
# If the device is in TCC mode, or does not support memory pools for some other reason,
987
998
# use the SynchronousMemoryResource which does not use memory pools.
988
999
if (
989
1000
handle_return (
990
- runtime .cudaDeviceGetAttribute (runtime .cudaDeviceAttr .cudaDevAttrMemoryPoolsSupported , 0 )
1001
+ driver .cuDeviceGetAttribute (
1002
+ driver .CUdevice_attribute .CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED , dev_id
1003
+ )
991
1004
)
992
1005
) == 1 :
993
1006
dev ._mr = _DefaultAsyncMempool (dev_id )
@@ -996,9 +1009,12 @@ def __new__(cls, device_id=None):
996
1009
997
1010
dev ._has_inited = False
998
1011
dev ._properties = None
999
- _tls . devices .append (dev )
1012
+ devices .append (dev )
1000
1013
1001
- return _tls .devices [device_id ]
1014
+ try :
1015
+ return devices [device_id ]
1016
+ except IndexError :
1017
+ raise ValueError (f"device_id must be within [0, { len (devices )} ), got { device_id } " ) from None
1002
1018
1003
1019
def _check_context_initialized (self , * args , ** kwargs ):
1004
1020
if not self ._has_inited :
0 commit comments