|
1 | | -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
4 | 4 | """GraphBuilder stream capture tests.""" |
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | import pytest |
8 | | -from helpers.graph_kernels import compile_common_kernels |
| 8 | +from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels |
9 | 9 | from helpers.marks import requires_module |
| 10 | +from helpers.misc import try_create_condition |
10 | 11 |
|
11 | 12 | from cuda.core import Device, GraphBuilder, LaunchConfig, LegacyPinnedMemoryResource, launch |
| 13 | +from cuda.core.graph import GraphDefinition |
12 | 14 |
|
13 | 15 |
|
14 | 16 | def test_graph_is_building(init_cuda): |
@@ -288,3 +290,190 @@ def test_graph_stream_lifetime(init_cuda): |
288 | 290 |
|
289 | 291 | # Destroy the stream |
290 | 292 | stream.close() |
| 293 | + |
| 294 | + |
| 295 | +# --------------------------------------------------------------------------- |
| 296 | +# GraphBuilder.graph_definition |
| 297 | +# --------------------------------------------------------------------------- |
| 298 | + |
| 299 | + |
def test_graph_definition_returns_graph_definition_after_end_building(init_cuda):
    """A finished primary capture is reachable as a ``GraphDefinition``.

    Two ``empty_kernel`` launches are captured. Once ``end_building()`` has
    been called, ``graph_definition`` is expected to return a
    ``GraphDefinition`` that reports exactly those two nodes.
    """
    kernel = compile_common_kernels().get_kernel("empty_kernel")

    builder = Device().create_graph_builder()
    builder = builder.begin_building()
    for _ in range(2):
        launch(builder, LaunchConfig(grid=1, block=1), kernel)
    builder.end_building()

    definition = builder.graph_definition
    assert isinstance(definition, GraphDefinition)
    # One node per captured launch.
    node_count = len(definition.nodes())
    assert node_count == 2
| 314 | + |
| 315 | + |
def test_graph_definition_raises_before_begin_building(init_cuda):
    """Reading ``graph_definition`` on a never-started primary builder must raise."""
    fresh_builder = Device().create_graph_builder()
    expect_error = pytest.raises(RuntimeError, match="before begin_building")
    with expect_error:
        # Property access alone is what should trigger the error.
        _ = fresh_builder.graph_definition
| 321 | + |
| 322 | + |
def test_graph_definition_raises_during_capture(init_cuda):
    """Between ``begin_building()`` and ``end_building()`` the property must raise."""
    builder = Device().create_graph_builder()
    builder = builder.begin_building()
    try:
        expect_error = pytest.raises(RuntimeError, match="capture is in")
        with expect_error:
            _ = builder.graph_definition
    finally:
        # Always terminate the capture, even if the expectation above fails.
        builder.end_building()
| 331 | + |
| 332 | + |
def test_graph_definition_raises_for_forked(init_cuda):
    """A builder produced by ``split()`` must refuse to expose ``graph_definition``.

    The capture is split in two after one launch. Accessing the property on
    the second branch is expected to raise a ``RuntimeError`` mentioning
    "forked".
    """
    kernel = compile_common_kernels().get_kernel("empty_kernel")

    root = Device().create_graph_builder()
    root = root.begin_building()
    launch(root, LaunchConfig(grid=1, block=1), kernel)
    primary, forked = root.split(2)
    try:
        with pytest.raises(RuntimeError, match="forked"):
            _ = forked.graph_definition
    finally:
        # Merge the branches again and finish the capture on whatever
        # builder join() hands back.
        joined = GraphBuilder.join(primary, forked)
        joined.end_building()
| 347 | + |
| 348 | + |
def test_graph_definition_shares_ownership(init_cuda):
    """A ``GraphDefinition`` fetched earlier stays usable after ``close()``.

    One kernel is captured, the definition is fetched, and only then is the
    builder closed. The definition must still report its single node.
    """
    kernel = compile_common_kernels().get_kernel("empty_kernel")

    builder = Device().create_graph_builder()
    builder = builder.begin_building()
    launch(builder, LaunchConfig(grid=1, block=1), kernel)
    builder.end_building()

    definition = builder.graph_definition
    builder.close()
    # The definition is expected to co-own the graph and so outlive the builder.
    node_count = len(definition.nodes())
    assert node_count == 1
| 362 | + |
| 363 | + |
def test_graph_definition_round_trips_through_explicit_api(init_cuda):
    """Nodes added through ``graph_definition`` survive ``complete()`` and execute.

    One ``add_one`` launch is captured on the builder. A second one is then
    appended after the captured node through the ``GraphDefinition`` view.
    Launching the completed graph must therefore increment the pinned
    counter twice.

    The pinned buffer is released in a ``finally`` clause so that a failing
    assertion or an error raised part-way through does not leak it.
    """
    mod = compile_common_kernels()
    add_one = mod.get_kernel("add_one")

    launch_stream = Device().create_stream()
    mr = LegacyPinnedMemoryResource()
    b = mr.allocate(4)
    try:
        # View the 4-byte allocation as a single int32 counter starting at 0.
        arr = np.from_dlpack(b).view(np.int32)
        arr[0] = 0

        gb = launch_stream.create_graph_builder().begin_building()
        launch(gb, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
        gb.end_building()

        # Add a second add_one through the explicit GraphDefinition view,
        # chained off the node that stream capture produced.
        gd = gb.graph_definition
        captured_node = next(iter(gd.nodes()))
        captured_node.launch(LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
        assert len(gd.nodes()) == 2

        graph = gb.complete()
        graph.launch(launch_stream)
        launch_stream.sync()
        # Both the captured and the explicitly added kernel must have run.
        assert arr[0] == 2
    finally:
        b.close()
| 391 | + |
| 392 | + |
@requires_module(np, "2.1")
def test_graph_definition_hybrid_conditional_body(init_cuda):
    """Populate a conditional body entirely through the explicit API.

    This is the headline hybrid flow enabled by the ``graph_definition``
    property: ``if_then`` returns a ``GraphBuilder`` for the body, but
    instead of calling ``begin_building`` and capturing into it, the test
    reaches for ``graph_definition`` and adds the node through the explicit
    API. The condition handle is set to 1, so the body's single ``add_one``
    is expected to run exactly once.

    The pinned buffer is released in a ``finally`` clause so that a failing
    assertion, or ``try_create_condition`` bailing out (presumably a skip on
    unsupported setups; verify against the helper), does not leak it.
    """
    mod = compile_conditional_kernels(int)
    add_one = mod.get_kernel("add_one")
    set_handle = mod.get_kernel("set_handle")

    launch_stream = Device().create_stream()
    mr = LegacyPinnedMemoryResource()
    b = mr.allocate(4)
    try:
        # View the 4-byte allocation as a single int32 counter starting at 0.
        arr = np.from_dlpack(b).view(np.int32)
        arr[0] = 0

        gb = Device().create_graph_builder().begin_building()
        condition = try_create_condition(gb)
        launch(gb, LaunchConfig(grid=1, block=1), set_handle, condition, 1)
        body_gb = gb.if_then(condition)

        # Skip body_gb.begin_building() entirely -- the body graph already
        # exists at conditional-node creation time and is exposed here.
        body_def = body_gb.graph_definition
        assert isinstance(body_def, GraphDefinition)
        assert len(body_def.nodes()) == 0
        body_def.launch(LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)

        graph = gb.end_building().complete()
        graph.launch(launch_stream)
        launch_stream.sync()
        assert arr[0] == 1
    finally:
        b.close()
| 430 | + |
| 431 | + |
@requires_module(np, "2.1")
def test_graph_definition_conditional_body_after_capture(init_cuda):
    """Capture into a conditional body, then augment it via the explicit API.

    The body builder first captures one ``add_one``. After its
    ``end_building()``, a second ``add_one`` is chained off the captured
    node through ``graph_definition``. With the condition handle set to 1,
    the completed graph must increment the pinned counter twice.

    The pinned buffer is released in a ``finally`` clause so that a failing
    assertion, or ``try_create_condition`` bailing out (presumably a skip on
    unsupported setups; verify against the helper), does not leak it.
    """
    mod = compile_conditional_kernels(int)
    add_one = mod.get_kernel("add_one")
    set_handle = mod.get_kernel("set_handle")

    launch_stream = Device().create_stream()
    mr = LegacyPinnedMemoryResource()
    b = mr.allocate(4)
    try:
        # View the 4-byte allocation as a single int32 counter starting at 0.
        arr = np.from_dlpack(b).view(np.int32)
        arr[0] = 0

        gb = Device().create_graph_builder().begin_building()
        condition = try_create_condition(gb)
        launch(gb, LaunchConfig(grid=1, block=1), set_handle, condition, 1)
        body_gb = gb.if_then(condition).begin_building()

        # Capture one increment into the body.
        launch(body_gb, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
        body_gb.end_building()

        # Add a second increment via the explicit API on the same body graph.
        body_def = body_gb.graph_definition
        captured_node = next(iter(body_def.nodes()))
        captured_node.launch(LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
        assert len(body_def.nodes()) == 2

        graph = gb.end_building().complete()
        graph.launch(launch_stream)
        launch_stream.sync()
        assert arr[0] == 2
    finally:
        b.close()
| 466 | + |
| 467 | + |
# NOTE(review): nothing in this test touches numpy directly -- confirm the
# version gate below is still required here.
@requires_module(np, "2.1")
def test_graph_definition_conditional_body_during_capture_raises(init_cuda):
    """A conditional body that is mid-capture must refuse ``graph_definition`` too."""
    outer = Device().create_graph_builder()
    outer = outer.begin_building()
    condition = try_create_condition(outer)
    body = outer.if_then(condition)
    body = body.begin_building()
    try:
        expect_error = pytest.raises(RuntimeError, match="capture is in")
        with expect_error:
            _ = body.graph_definition
    finally:
        # Unwind inner capture first, then the enclosing one.
        body.end_building()
        outer.end_building()
0 commit comments