Skip to content

Commit de85b92

Browse files
cuda.core: add GraphBuilder.graph_definition property
Completes step 3 of #1330 by exposing the captured graph as an explicit `GraphDefinition` view that shares ownership of the underlying `CUgraph`. The handle-layer plumbing landed in PR #2008; this commit wires up the user-facing surface and locks in the state-guard rules.

State semantics:

- PRIMARY builder: only valid after `end_building()`. Before `begin_building()` no graph exists; during capture the driver is the sole writer, so explicit access is unsafe.
- CONDITIONAL_BODY builder: valid both before `begin_building()` (the body graph is allocated at conditional-node creation time) and after `end_building()`. This enables a hybrid flow where a conditional body is populated entirely via the explicit API, with no capture at all.
- FORKED builder: never valid. Forked builders share the primary's graph; access through the primary instead.

Tests cover the happy path, both hybrid flows on conditional bodies (populate-via-explicit-API and capture-then-augment), the three error states (forked, capturing, primary pre-capture), and the shared-ownership guarantee (the `GraphDefinition` survives the builder's `close()`).

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent f550183 commit de85b92

2 files changed

Lines changed: 250 additions & 3 deletions

File tree

cuda_core/cuda/core/graph/_graph_builder.pyx

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ from libc.stdint cimport intptr_t
88

99
from cuda.bindings cimport cydriver
1010

11-
from cuda.core.graph._graph_definition cimport GraphCondition
11+
from cuda.core.graph._graph_definition cimport GraphCondition, GraphDefinition
1212
from cuda.core.graph._utils cimport _attach_host_callback_to_graph
1313
from cuda.core._resource_handles cimport (
1414
as_cu, as_py,
@@ -279,6 +279,64 @@ cdef class GraphBuilder:
279279
"""Returns True if this graph builder must be joined before building is ended."""
280280
return self._kind == FORKED
281281

282+
@property
def graph_definition(self) -> GraphDefinition:
    """Explicit :class:`~graph.GraphDefinition` view of this builder's graph.

    The view refers to the very graph the builder produces, so nodes
    added through it show up in later :meth:`complete` and
    :meth:`debug_dot_print` calls, and it remains usable after the
    builder has been closed.

    Use it to combine capture with the explicit API on one graph:
    inspect captured work, append further nodes, or fill a conditional
    body without capturing anything.

    When it can be read depends on the kind of builder:

    - **Primary builders** (from :meth:`Device.create_graph_builder` or
      :meth:`Stream.create_graph_builder`): only once
      :meth:`end_building` has been called.

    - **Conditional-body builders** (from :meth:`if_then`,
      :meth:`if_else`, :meth:`while_loop`, :meth:`switch`): before
      :meth:`begin_building` as well as after :meth:`end_building`.
      The body graph exists as soon as the conditional is created, so
      it may be populated through this view without
      :meth:`begin_building` ever being called on the body builder.

    - **Forked builders** (from :meth:`split`): never. They share the
      primary builder's graph, which must be reached through the
      primary.

    Returns
    -------
    GraphDefinition
        A view of the graph being built.

    Raises
    ------
    RuntimeError
        If the builder is forked, is currently building, or is a
        primary builder on which building has not started.
    """
    # The checks run in a fixed order (forked, then capturing, then
    # primary-not-started); the first one that applies picks the message.
    if self._kind == FORKED:
        reason = (
            "graph_definition is unavailable on forked graph builders; "
            "access it through the primary builder instead."
        )
    elif self._state == CAPTURING:
        reason = (
            "graph_definition is unavailable while capture is in "
            "progress; call end_building() first."
        )
    elif self._kind == PRIMARY and self._state == CAPTURE_NOT_STARTED:
        reason = (
            "graph_definition is unavailable before begin_building() on "
            "a primary builder; no graph has been created yet."
        )
    else:
        # Every guard passed: wrap the builder's own graph handle.
        return GraphDefinition._from_handle(self._h_graph)
    raise RuntimeError(reason)
339+
282340
def begin_building(self, mode="relaxed") -> GraphBuilder:
283341
"""Begins the building process.
284342

cuda_core/tests/graph/test_graph_builder.py

Lines changed: 191 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

44
"""GraphBuilder stream capture tests."""
55

66
import numpy as np
77
import pytest
8-
from helpers.graph_kernels import compile_common_kernels
8+
from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels
99
from helpers.marks import requires_module
10+
from helpers.misc import try_create_condition
1011

1112
from cuda.core import Device, GraphBuilder, LaunchConfig, LegacyPinnedMemoryResource, launch
13+
from cuda.core.graph import GraphDefinition
1214

1315

1416
def test_graph_is_building(init_cuda):
@@ -288,3 +290,190 @@ def test_graph_stream_lifetime(init_cuda):
288290

289291
# Destroy the stream
290292
stream.close()
293+
294+
295+
# ---------------------------------------------------------------------------
296+
# GraphBuilder.graph_definition
297+
# ---------------------------------------------------------------------------
298+
299+
300+
def test_graph_definition_returns_graph_definition_after_end_building(init_cuda):
    """After end_building(), a primary builder's graph_definition is a GraphDefinition with the captured nodes."""
    kernel = compile_common_kernels().get_kernel("empty_kernel")

    builder = Device().create_graph_builder().begin_building()
    # Capture two launches so the node count checked below cannot be
    # confused with an empty graph.
    for _ in range(2):
        launch(builder, LaunchConfig(grid=1, block=1), kernel)
    builder.end_building()

    definition = builder.graph_definition
    assert isinstance(definition, GraphDefinition)
    # One node per captured kernel launch.
    assert len(definition.nodes()) == 2
314+
315+
316+
def test_graph_definition_raises_before_begin_building(init_cuda):
    """Reading graph_definition on a primary builder that never began building raises RuntimeError."""
    unstarted_builder = Device().create_graph_builder()
    # The regex pins the "before begin_building" wording of the error.
    with pytest.raises(RuntimeError, match="before begin_building"):
        _ = unstarted_builder.graph_definition
321+
322+
323+
def test_graph_definition_raises_during_capture(init_cuda):
    """Reading graph_definition between begin_building() and end_building() raises RuntimeError."""
    capturing_builder = Device().create_graph_builder().begin_building()
    try:
        with pytest.raises(RuntimeError, match="capture is in"):
            _ = capturing_builder.graph_definition
    finally:
        # Always finish building, even if the expectation above fails.
        capturing_builder.end_building()
331+
332+
333+
def test_graph_definition_raises_for_forked(init_cuda):
    """A builder obtained from split() refuses graph_definition with a RuntimeError."""
    kernel = compile_common_kernels().get_kernel("empty_kernel")

    root = Device().create_graph_builder().begin_building()
    launch(root, LaunchConfig(grid=1, block=1), kernel)
    trunk, fork = root.split(2)
    try:
        with pytest.raises(RuntimeError, match="forked"):
            _ = fork.graph_definition
    finally:
        # Join the branches again and finish building on whatever join()
        # returns, so the capture is closed even when the assertion fails.
        fork = GraphBuilder.join(trunk, fork)
        fork.end_building()
347+
348+
349+
def test_graph_definition_shares_ownership(init_cuda):
    """A GraphDefinition obtained before close() is still readable after the builder is closed."""
    kernel = compile_common_kernels().get_kernel("empty_kernel")

    builder = Device().create_graph_builder().begin_building()
    launch(builder, LaunchConfig(grid=1, block=1), kernel)
    builder.end_building()

    # Take the view first, then close the builder that produced it.
    definition = builder.graph_definition
    builder.close()
    # The single captured launch must still be visible through the view.
    assert len(definition.nodes()) == 1
362+
363+
364+
@requires_module(np, "2.1")
def test_graph_definition_round_trips_through_explicit_api(init_cuda):
    """Nodes added through graph_definition are kept by complete() and execute.

    One ``add_one`` launch is captured; a second one is added through the
    captured node's ``launch`` method on the explicit view. The completed
    graph must therefore increment the counter twice.

    Gated on NumPy 2.1 like the other tests in this module that view a
    pinned allocation through ``np.from_dlpack``.
    """
    mod = compile_common_kernels()
    add_one = mod.get_kernel("add_one")

    # Host-visible int32 counter the kernels increment.
    launch_stream = Device().create_stream()
    mr = LegacyPinnedMemoryResource()
    b = mr.allocate(4)
    arr = np.from_dlpack(b).view(np.int32)
    arr[0] = 0

    # Capture the first increment.
    gb = launch_stream.create_graph_builder().begin_building()
    launch(gb, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
    gb.end_building()

    # Add a second add_one through the explicit GraphDefinition view.
    gd = gb.graph_definition
    captured_node = next(iter(gd.nodes()))
    captured_node.launch(LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
    assert len(gd.nodes()) == 2

    # Both increments must be part of the completed graph.
    graph = gb.complete()
    graph.launch(launch_stream)
    launch_stream.sync()
    assert arr[0] == 2

    b.close()
391+
392+
393+
@requires_module(np, "2.1")
def test_graph_definition_hybrid_conditional_body(init_cuda):
    """Fill an ``if_then`` body through the explicit API only.

    ``if_then`` hands back a ``GraphBuilder`` for the body. This test never
    calls ``begin_building`` on it; it takes ``graph_definition`` straight
    away and adds the body's single kernel launch through that view.
    """
    kernels = compile_conditional_kernels(int)
    add_one = kernels.get_kernel("add_one")
    set_handle = kernels.get_kernel("set_handle")

    # Host-visible int32 counter the body kernel increments.
    launch_stream = Device().create_stream()
    pinned = LegacyPinnedMemoryResource()
    buf = pinned.allocate(4)
    counter = np.from_dlpack(buf).view(np.int32)
    counter[0] = 0

    outer = Device().create_graph_builder().begin_building()
    condition = try_create_condition(outer)
    launch(outer, LaunchConfig(grid=1, block=1), set_handle, condition, 1)
    body_builder = outer.if_then(condition)

    # No begin_building() on the body: its graph is already available
    # through the property and starts out empty.
    body_view = body_builder.graph_definition
    assert isinstance(body_view, GraphDefinition)
    assert len(body_view.nodes()) == 0
    body_view.launch(LaunchConfig(grid=1, block=1), add_one, counter.ctypes.data)

    outer.end_building()
    graph = outer.complete()
    graph.launch(launch_stream)
    launch_stream.sync()
    # The explicitly added body launch ran exactly once.
    assert counter[0] == 1

    buf.close()
430+
431+
432+
@requires_module(np, "2.1")
def test_graph_definition_conditional_body_after_capture(init_cuda):
    """A conditional body that was captured into can be extended through graph_definition."""
    kernels = compile_conditional_kernels(int)
    add_one = kernels.get_kernel("add_one")
    set_handle = kernels.get_kernel("set_handle")

    # Host-visible int32 counter the body kernels increment.
    launch_stream = Device().create_stream()
    pinned = LegacyPinnedMemoryResource()
    buf = pinned.allocate(4)
    counter = np.from_dlpack(buf).view(np.int32)
    counter[0] = 0

    outer = Device().create_graph_builder().begin_building()
    condition = try_create_condition(outer)
    launch(outer, LaunchConfig(grid=1, block=1), set_handle, condition, 1)
    body_builder = outer.if_then(condition)
    body_builder = body_builder.begin_building()

    # First increment: captured into the body.
    launch(body_builder, LaunchConfig(grid=1, block=1), add_one, counter.ctypes.data)
    body_builder.end_building()

    # Second increment: added through the explicit view of the same body
    # graph, via the node that capture produced.
    body_view = body_builder.graph_definition
    first_node = next(iter(body_view.nodes()))
    first_node.launch(LaunchConfig(grid=1, block=1), add_one, counter.ctypes.data)
    assert len(body_view.nodes()) == 2

    outer.end_building()
    graph = outer.complete()
    graph.launch(launch_stream)
    launch_stream.sync()
    # Captured launch plus explicitly added launch.
    assert counter[0] == 2

    buf.close()
466+
467+
468+
@requires_module(np, "2.1")
def test_graph_definition_conditional_body_during_capture_raises(init_cuda):
    """A conditional-body builder that is currently building also rejects graph_definition."""
    # NOTE(review): this test does not use NumPy itself; the version gate
    # presumably mirrors the other conditional tests -- confirm it is needed.
    outer = Device().create_graph_builder().begin_building()
    condition = try_create_condition(outer)
    body_builder = outer.if_then(condition)
    body_builder = body_builder.begin_building()
    try:
        with pytest.raises(RuntimeError, match="capture is in"):
            _ = body_builder.graph_definition
    finally:
        # Finish the inner body first, then the enclosing builder.
        body_builder.end_building()
        outer.end_building()

0 commit comments

Comments
 (0)