Skip to content

Commit 79e4bf7

Browse files
Cythonize away some perf hot spots (#709)
* cythonize event
* cythonize context/event/util
* Cython 3.0+ supports __del__ for cdef classes
* inline precondition to reduce overhead
* centralize check_or_create_options
* add back error check
* cythonize stream + bug fixes
* make linter happy
* bug fix
* Cython mis-compiles Optional types
* Modified _get_context_device() helper routine

  To fix test failures with CTK 11.8 and driver 535.247.01
  only attempt to query _ctx_handle if _device_id is None.

  Ensure that context handle is set in Stream.context property
* Verify that stream.context handle is not None in test_stream_context
* cache success enums
* nit: avoid cdef void
* In Event.close set handle to None before raising error

---------

Co-authored-by: Oleksandr Pavlyk <[email protected]>
1 parent 5576da6 commit 79e4bf7

File tree

10 files changed

+226
-175
lines changed

10 files changed

+226
-175
lines changed

cuda_core/cuda/core/experimental/_context.py renamed to cuda_core/cuda/core/experimental/_context.pyx

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,21 @@ class ContextOptions:
1313
pass # TODO
1414

1515

16-
class Context:
17-
__slots__ = ("_handle", "_id")
16+
cdef class Context:
1817

19-
def __new__(self, *args, **kwargs):
18+
cdef:
19+
readonly object _handle
20+
int _device_id
21+
22+
def __init__(self, *args, **kwargs):
2023
raise RuntimeError("Context objects cannot be instantiated directly. Please use Device or Stream APIs.")
2124

2225
@classmethod
23-
def _from_ctx(cls, obj, dev_id):
24-
assert_type(obj, driver.CUcontext)
25-
ctx = super().__new__(cls)
26-
ctx._handle = obj
27-
ctx._id = dev_id
26+
def _from_ctx(cls, handle: driver.CUcontext, int device_id):
27+
cdef Context ctx = Context.__new__(Context)
28+
ctx._handle = handle
29+
ctx._device_id = device_id
2830
return ctx
31+
32+
def __eq__(self, other):
33+
return int(self._handle) == int(other._handle)

cuda_core/cuda/core/experimental/_device.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
_check_driver_error,
1818
driver,
1919
handle_return,
20-
precondition,
2120
runtime,
2221
)
2322

@@ -1017,12 +1016,31 @@ def __new__(cls, device_id: Optional[int] = None):
10171016
except IndexError:
10181017
raise ValueError(f"device_id must be within [0, {len(devices)}), got {device_id}") from None
10191018

1020-
def _check_context_initialized(self, *args, **kwargs):
1019+
def _check_context_initialized(self):
10211020
if not self._has_inited:
10221021
raise CUDAError(
10231022
f"Device {self._id} is not yet initialized, perhaps you forgot to call .set_current() first?"
10241023
)
10251024

1025+
def _get_current_context(self, check_consistency=False) -> driver.CUcontext:
1026+
err, ctx = driver.cuCtxGetCurrent()
1027+
1028+
# TODO: We want to just call this:
1029+
# _check_driver_error(err)
1030+
# but even the simplest success check causes 50-100 ns. Wait until we cythonize this file...
1031+
if ctx is None:
1032+
_check_driver_error(err)
1033+
1034+
if int(ctx) == 0:
1035+
raise CUDAError("No context is bound to the calling CPU thread.")
1036+
if check_consistency:
1037+
err, dev = driver.cuCtxGetDevice()
1038+
if err != _SUCCESS:
1039+
handle_return((err,))
1040+
if int(dev) != self._id:
1041+
raise CUDAError("Internal error (current device is not equal to Device.device_id)")
1042+
return ctx
1043+
10261044
@property
10271045
def device_id(self) -> int:
10281046
"""Return device ordinal."""
@@ -1083,7 +1101,6 @@ def compute_capability(self) -> ComputeCapability:
10831101
return cc
10841102

10851103
@property
1086-
@precondition(_check_context_initialized)
10871104
def context(self) -> Context:
10881105
"""Return the current :obj:`~_context.Context` associated with this device.
10891106
@@ -1092,9 +1109,8 @@ def context(self) -> Context:
10921109
Device must be initialized.
10931110
10941111
"""
1095-
ctx = handle_return(driver.cuCtxGetCurrent())
1096-
if int(ctx) == 0:
1097-
raise CUDAError("No context is bound to the calling CPU thread.")
1112+
self._check_context_initialized()
1113+
ctx = self._get_current_context(check_consistency=True)
10981114
return Context._from_ctx(ctx, self._id)
10991115

11001116
@property
@@ -1206,8 +1222,7 @@ def create_context(self, options: ContextOptions = None) -> Context:
12061222
"""
12071223
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/189")
12081224

1209-
@precondition(_check_context_initialized)
1210-
def create_stream(self, obj: Optional[IsStreamT] = None, options: StreamOptions = None) -> Stream:
1225+
def create_stream(self, obj: Optional[IsStreamT] = None, options: Optional[StreamOptions] = None) -> Stream:
12111226
"""Create a Stream object.
12121227
12131228
New stream objects can be created in two different ways:
@@ -1235,9 +1250,9 @@ def create_stream(self, obj: Optional[IsStreamT] = None, options: StreamOptions
12351250
Newly created stream object.
12361251
12371252
"""
1238-
return Stream._init(obj=obj, options=options)
1253+
self._check_context_initialized()
1254+
return Stream._init(obj=obj, options=options, device_id=self._id)
12391255

1240-
@precondition(_check_context_initialized)
12411256
def create_event(self, options: Optional[EventOptions] = None) -> Event:
12421257
"""Create an Event object without recording it to a Stream.
12431258
@@ -1256,9 +1271,10 @@ def create_event(self, options: Optional[EventOptions] = None) -> Event:
12561271
Newly created event object.
12571272
12581273
"""
1259-
return Event._init(self._id, self.context._handle, options)
1274+
self._check_context_initialized()
1275+
ctx = self._get_current_context()
1276+
return Event._init(self._id, ctx, options)
12601277

1261-
@precondition(_check_context_initialized)
12621278
def allocate(self, size, stream: Optional[Stream] = None) -> Buffer:
12631279
"""Allocate device memory from a specified stream.
12641280
@@ -1285,11 +1301,11 @@ def allocate(self, size, stream: Optional[Stream] = None) -> Buffer:
12851301
Newly created buffer object.
12861302
12871303
"""
1304+
self._check_context_initialized()
12881305
if stream is None:
12891306
stream = default_stream()
12901307
return self._mr.allocate(size, stream)
12911308

1292-
@precondition(_check_context_initialized)
12931309
def sync(self):
12941310
"""Synchronize the device.
12951311
@@ -1298,9 +1314,9 @@ def sync(self):
12981314
Device must be initialized.
12991315
13001316
"""
1317+
self._check_context_initialized()
13011318
handle_return(runtime.cudaDeviceSynchronize())
13021319

1303-
@precondition(_check_context_initialized)
13041320
def create_graph_builder(self) -> GraphBuilder:
13051321
"""Create a new :obj:`~_graph.GraphBuilder` object.
13061322
@@ -1310,4 +1326,5 @@ def create_graph_builder(self) -> GraphBuilder:
13101326
Newly created graph builder object.
13111327
13121328
"""
1329+
self._check_context_initialized()
13131330
return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True)

cuda_core/cuda/core/experimental/_event.py renamed to cuda_core/cuda/core/experimental/_event.pyx

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,28 @@
44

55
from __future__ import annotations
66

7-
import weakref
7+
from cuda.core.experimental._utils.cuda_utils cimport (
8+
_check_driver_error as raise_if_driver_error,
9+
check_or_create_options,
10+
)
11+
812
from dataclasses import dataclass
913
from typing import TYPE_CHECKING, Optional
1014

1115
from cuda.core.experimental._context import Context
1216
from cuda.core.experimental._utils.cuda_utils import (
1317
CUDAError,
14-
check_or_create_options,
1518
driver,
1619
handle_return,
1720
)
18-
from cuda.core.experimental._utils.cuda_utils import (
19-
_check_driver_error as raise_if_driver_error,
20-
)
2121

2222
if TYPE_CHECKING:
2323
import cuda.bindings
2424
from cuda.core.experimental._device import Device
2525

2626

2727
@dataclass
28-
class EventOptions:
28+
cdef class EventOptions:
2929
"""Customizable :obj:`~_event.Event` options.
3030
3131
Attributes
@@ -49,7 +49,7 @@ class EventOptions:
4949
support_ipc: Optional[bool] = False
5050

5151

52-
class Event:
52+
cdef class Event:
5353
"""Represent a record at a specific point of execution within a CUDA stream.
5454
5555
Applications can asynchronously record events at any point in
@@ -77,49 +77,46 @@ class Event:
7777
and they should instead be created through a :obj:`~_stream.Stream` object.
7878
7979
"""
80-
81-
class _MembersNeededForFinalize:
82-
__slots__ = ("handle",)
83-
84-
def __init__(self, event_obj, handle):
85-
self.handle = handle
86-
weakref.finalize(event_obj, self.close)
87-
88-
def close(self):
89-
if self.handle is not None:
90-
handle_return(driver.cuEventDestroy(self.handle))
91-
self.handle = None
92-
93-
def __new__(self, *args, **kwargs):
80+
cdef:
81+
object _handle
82+
bint _timing_disabled
83+
bint _busy_waited
84+
int _device_id
85+
object _ctx_handle
86+
87+
def __init__(self, *args, **kwargs):
9488
raise RuntimeError("Event objects cannot be instantiated directly. Please use Stream APIs (record).")
9589

96-
__slots__ = ("__weakref__", "_mnff", "_timing_disabled", "_busy_waited", "_device_id", "_ctx_handle")
97-
9890
@classmethod
99-
def _init(cls, device_id: int, ctx_handle: Context, options: Optional[EventOptions] = None):
100-
self = super().__new__(cls)
101-
self._mnff = Event._MembersNeededForFinalize(self, None)
102-
103-
options = check_or_create_options(EventOptions, options, "Event options")
91+
def _init(cls, device_id: int, ctx_handle: Context, options=None):
92+
cdef Event self = Event.__new__(Event)
93+
cdef EventOptions opts = check_or_create_options(EventOptions, options, "Event options")
10494
flags = 0x0
10595
self._timing_disabled = False
10696
self._busy_waited = False
107-
if not options.enable_timing:
97+
if not opts.enable_timing:
10898
flags |= driver.CUevent_flags.CU_EVENT_DISABLE_TIMING
10999
self._timing_disabled = True
110-
if options.busy_waited_sync:
100+
if opts.busy_waited_sync:
111101
flags |= driver.CUevent_flags.CU_EVENT_BLOCKING_SYNC
112102
self._busy_waited = True
113-
if options.support_ipc:
103+
if opts.support_ipc:
114104
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/103")
115-
self._mnff.handle = handle_return(driver.cuEventCreate(flags))
105+
err, self._handle = driver.cuEventCreate(flags)
106+
raise_if_driver_error(err)
116107
self._device_id = device_id
117108
self._ctx_handle = ctx_handle
118109
return self
119110

120-
def close(self):
111+
cpdef close(self):
121112
"""Destroy the event."""
122-
self._mnff.close()
113+
if self._handle is not None:
114+
err, = driver.cuEventDestroy(self._handle)
115+
self._handle = None
116+
raise_if_driver_error(err)
117+
118+
def __del__(self):
119+
self.close()
123120

124121
def __isub__(self, other):
125122
return NotImplemented
@@ -129,7 +126,7 @@ def __rsub__(self, other):
129126

130127
def __sub__(self, other):
131128
# return self - other (in milliseconds)
132-
err, timing = driver.cuEventElapsedTime(other.handle, self.handle)
129+
err, timing = driver.cuEventElapsedTime(other.handle, self._handle)
133130
try:
134131
raise_if_driver_error(err)
135132
return timing
@@ -180,12 +177,12 @@ def sync(self):
180177
has been completed.
181178
182179
"""
183-
handle_return(driver.cuEventSynchronize(self._mnff.handle))
180+
handle_return(driver.cuEventSynchronize(self._handle))
184181

185182
@property
186183
def is_done(self) -> bool:
187184
"""Return True if all captured works have been completed, otherwise False."""
188-
(result,) = driver.cuEventQuery(self._mnff.handle)
185+
result, = driver.cuEventQuery(self._handle)
189186
if result == driver.CUresult.CUDA_SUCCESS:
190187
return True
191188
if result == driver.CUresult.CUDA_ERROR_NOT_READY:
@@ -201,7 +198,7 @@ def handle(self) -> cuda.bindings.driver.CUevent:
201198
This handle is a Python object. To get the memory address of the underlying C
202199
handle, call ``int(Event.handle)``.
203200
"""
204-
return self._mnff.handle
201+
return self._handle
205202

206203
@property
207204
def device(self) -> Device:

0 commit comments

Comments
 (0)