Commit 0fe2309

Switch to use CUDA driver APIs in Device constructor (#460)
* cache cc to speed it up
* avoid using cudart APIs in Device constructor
* avoid silly, redundant lock
* minor perf opt: try-except + skip assert
* also optimize for explicit dev id
* update release notes
* debug sanitizer
* fix type hint; compare against enums
* Revert "debug sanitizer"
  This reverts commit d279e50.
1 parent 414b124 commit 0fe2309
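
The "minor perf opt: try-except + skip assert" bullet refers to a common CPython micro-optimization that the diff below applies to the per-thread device cache: hasattr() performs the same attribute lookup as a direct access and then discards the result, so on the hot path (cache already populated) the EAFP form pays for one lookup instead of two. A minimal sketch, not part of the commit, illustrating the difference:

    # Minimal sketch (not from the commit) of the hasattr -> try/except change.
    # When the attribute almost always exists, one direct load inside
    # try/except is cheaper than hasattr() plus a second load.
    import threading
    import timeit

    _tls = threading.local()
    _tls.devices = ["dev0"]  # warm cache: the common case on the hot path

    def lookup_lbyl():  # look-before-you-leap, as in the old code
        if not hasattr(_tls, "devices"):
            _tls.devices = []
        return _tls.devices

    def lookup_eafp():  # easier-to-ask-forgiveness, as in the new code
        try:
            return _tls.devices
        except AttributeError:
            _tls.devices = []
            return _tls.devices

    print("hasattr:   ", timeit.timeit(lookup_lbyl))  # two lookups per call
    print("try/except:", timeit.timeit(lookup_eafp))  # one lookup per call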

File tree: 2 files changed (+33, -16 lines)

cuda_core/cuda/core/experimental/_device.py

Lines changed: 30 additions & 14 deletions
@@ -13,6 +13,7 @@
 from cuda.core.experimental._utils.cuda_utils import (
     ComputeCapability,
     CUDAError,
+    _check_driver_error,
     driver,
     handle_return,
     precondition,
@@ -930,6 +931,10 @@ def multicast_supported(self) -> bool:
         return bool(self._get_cached_attribute(driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED))
 
 
+_SUCCESS = driver.CUresult.CUDA_SUCCESS
+_INVALID_CTX = driver.CUresult.CUDA_ERROR_INVALID_CONTEXT
+
+
 class Device:
     """Represent a GPU and act as an entry point for cuda.core features.
 
@@ -959,7 +964,7 @@ class Device:
 
     __slots__ = ("_id", "_mr", "_has_inited", "_properties")
 
-    def __new__(cls, device_id=None):
+    def __new__(cls, device_id: Optional[int] = None):
         global _is_cuInit
         if _is_cuInit is False:
             with _lock:
@@ -968,26 +973,34 @@ def __new__(cls, device_id=None):
 
         # important: creating a Device instance does not initialize the GPU!
         if device_id is None:
-            device_id = handle_return(runtime.cudaGetDevice())
-            assert_type(device_id, int)
-        else:
-            total = handle_return(runtime.cudaGetDeviceCount())
-            assert_type(device_id, int)
-            if not (0 <= device_id < total):
-                raise ValueError(f"device_id must be within [0, {total}), got {device_id}")
+            err, dev = driver.cuCtxGetDevice()
+            if err == _SUCCESS:
+                device_id = int(dev)
+            elif err == _INVALID_CTX:
+                ctx = handle_return(driver.cuCtxGetCurrent())
+                assert int(ctx) == 0
+                device_id = 0  # cudart behavior
+            else:
+                _check_driver_error(err)
+        elif device_id < 0:
+            raise ValueError(f"device_id must be >= 0, got {device_id}")
 
         # ensure Device is singleton
-        if not hasattr(_tls, "devices"):
-            total = handle_return(runtime.cudaGetDeviceCount())
-            _tls.devices = []
+        try:
+            devices = _tls.devices
+        except AttributeError:
+            total = handle_return(driver.cuDeviceGetCount())
+            devices = _tls.devices = []
             for dev_id in range(total):
                 dev = super().__new__(cls)
                 dev._id = dev_id
                 # If the device is in TCC mode, or does not support memory pools for some other reason,
                 # use the SynchronousMemoryResource which does not use memory pools.
                 if (
                     handle_return(
-                        runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
+                        driver.cuDeviceGetAttribute(
+                            driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id
+                        )
                     )
                 ) == 1:
                     dev._mr = _DefaultAsyncMempool(dev_id)
@@ -996,9 +1009,12 @@ def __new__(cls, device_id=None):
 
             dev._has_inited = False
             dev._properties = None
-            _tls.devices.append(dev)
+            devices.append(dev)
 
-        return _tls.devices[device_id]
+        try:
+            return devices[device_id]
+        except IndexError:
+            raise ValueError(f"device_id must be within [0, {len(devices)}), got {device_id}") from None
 
     def _check_context_initialized(self, *args, **kwargs):
         if not self._has_inited:
cuda_core/docs/source/release/0.3.0-notes.rst

Lines changed: 3 additions & 2 deletions
@@ -22,7 +22,7 @@ New features
 
 - :class:`Kernel` adds :attr:`Kernel.num_arguments` and :attr:`Kernel.arguments_info` for introspection of kernel arguments. (#612)
 - Add pythonic access to kernel occupancy calculation functions via :attr:`Kernel.occupancy`. (#648)
-- Support launching cooperative kernels by setting :property:`LaunchConfig.cooperative_launch` to `True`.
+- Support launching cooperative kernels by setting :attr:`LaunchConfig.cooperative_launch` to `True`.
 - A name can be assigned to :class:`ObjectCode` instances generated by both :class:`Program` and :class:`Linker` through their respective
   options.
 
@@ -34,5 +34,6 @@ New examples
 Fixes and enhancements
 ----------------------
 
-- An :class:`Event` can now be used to look up its corresponding device and context using the ``.device`` and ``.context`` attributes respectively.
+- Look-up of the :attr:`Event.device` and :attr:`Event.context` (the device and CUDA context where an event was created from) is now possible.
 - The :func:`launch` function's handling of fp16 scalars was incorrect and is fixed.
+- The :class:`Device` constructor is made faster.
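
A short usage sketch of the reworked constructor, assuming cuda.core's public Device API (the `device_id` property and `set_current()` method are not shown in this diff); the behavior illustrated follows from the constructor code above:

    # Usage sketch (assumes cuda.core.experimental's public Device API).
    from cuda.core.experimental import Device

    dev = Device()   # no argument: binds to the current context's device,
                     # or to device 0 if no context exists yet
    assert Device(dev.device_id) is dev  # Device is a per-thread singleton

    dev.set_current()  # creating a Device does not initialize the GPU;
                       # only now is the device/context made active
    try:
        Device(-1)
    except ValueError as e:
        print(e)       # "device_id must be >= 0, got -1"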
