Commit 430e890

Support cooperative launch (#676)
* support cooperative launch
* update dev attr test
* use f-string
* add a naive skipif_need_cuda_headers fixture
* micro-optimization

Parent commit: 6c8ab73

File tree: 9 files changed, +97 −4 lines
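
Taken together, these changes let a cuda.core user opt into cooperative launch by flagging it on the launch configuration. Below is a minimal usage sketch, closely modeled on the new test added at the end of this commit; it assumes a device that reports cooperative-launch support, that CUDA_PATH points at a toolkit whose headers provide cooperative_groups.h, and the kernel name grid_sync is just an example:

    import os
    import pathlib

    from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch

    dev = Device()
    dev.set_current()
    stream = dev.create_stream()

    # A kernel with a grid-wide barrier; it must be launched cooperatively.
    code = r"""
    #include <cooperative_groups.h>

    extern "C" __global__ void grid_sync() {
        cooperative_groups::this_grid().sync();
    }
    """

    arch = "".join(f"{i}" for i in dev.compute_capability)
    include_path = str(pathlib.Path(os.environ["CUDA_PATH"]) / "include")
    prog = Program(
        code,
        code_type="c++",
        options=ProgramOptions(std="c++17", arch=f"sm_{arch}", include_path=include_path),
    )
    ker = prog.compile("cubin").get_kernel("grid_sync")

    # Setting cooperative_launch=True makes launch() validate the grid size and
    # pass the cooperative launch attribute down to cuLaunchKernelEx.
    config = LaunchConfig(grid=1, block=128, cooperative_launch=True)
    launch(stream, config, ker)
    stream.sync()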

cuda_core/cuda/core/experimental/_device.py

Lines changed: 11 additions & 0 deletions

@@ -701,6 +701,17 @@ def can_use_host_pointer_for_registered_mem(self) -> bool:
            )
        )

+    # TODO: A few attrs are missing here (NVIDIA/cuda-python#675)
+
+    @property
+    def cooperative_launch(self) -> bool:
+        """
+        True if device supports launching cooperative kernels, False if not.
+        """
+        return bool(self._get_cached_attribute(driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH))
+
+    # TODO: A few attrs are missing here (NVIDIA/cuda-python#675)
+
    @property
    def max_shared_memory_per_block_optin(self) -> int:
        """

cuda_core/cuda/core/experimental/_launch_config.py

Lines changed: 11 additions & 0 deletions

@@ -58,11 +58,15 @@ class LaunchConfig:
    cluster: Union[tuple, int] = None
    block: Union[tuple, int] = None
    shmem_size: Optional[int] = None
+    cooperative_launch: Optional[bool] = False

    def __post_init__(self):
        _lazy_init()
        self.grid = cast_to_3_tuple("LaunchConfig.grid", self.grid)
        self.block = cast_to_3_tuple("LaunchConfig.block", self.block)
+        # FIXME: Calling Device() strictly speaking is not quite right; we should instead
+        # look up the device from stream. We probably need to defer the checks related to
+        # device compute capability or attributes.
        # thread block clusters are supported starting H100
        if self.cluster is not None:
            if not _use_ex:

@@ -77,6 +81,8 @@ def __post_init__(self):
            self.cluster = cast_to_3_tuple("LaunchConfig.cluster", self.cluster)
        if self.shmem_size is None:
            self.shmem_size = 0
+        if self.cooperative_launch and not Device().properties.cooperative_launch:
+            raise CUDAError("cooperative kernels are not supported on this device")


def _to_native_launch_config(config: LaunchConfig) -> driver.CUlaunchConfig:

@@ -92,6 +98,11 @@ def _to_native_launch_config(config: LaunchConfig) -> driver.CUlaunchConfig:
        dim = attr.value.clusterDim
        dim.x, dim.y, dim.z = config.cluster
        attrs.append(attr)
+    if config.cooperative_launch:
+        attr = driver.CUlaunchAttribute()
+        attr.id = driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_COOPERATIVE
+        attr.value.cooperative = 1
+        attrs.append(attr)
    drv_cfg.numAttrs = len(attrs)
    drv_cfg.attrs = attrs
    return drv_cfg
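
At the configuration level the new flag does two things: __post_init__ rejects the option early when the device does not report cooperative-launch support, and _to_native_launch_config appends a CU_LAUNCH_ATTRIBUTE_COOPERATIVE attribute to the driver-level CUlaunchConfig consumed by cuLaunchKernelEx. A small sketch of the user-facing side, assuming the current device supports the feature:

    from cuda.core.experimental import Device, LaunchConfig

    dev = Device()
    dev.set_current()

    # Raises CUDAError at construction time if dev.properties.cooperative_launch is False;
    # otherwise the flag is carried through to the native launch attributes.
    config = LaunchConfig(grid=4, block=128, cooperative_launch=True)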

cuda_core/cuda/core/experimental/_launcher.py

Lines changed: 16 additions & 0 deletions

@@ -9,6 +9,7 @@
from cuda.core.experimental._stream import Stream
from cuda.core.experimental._utils.clear_error_support import assert_type
from cuda.core.experimental._utils.cuda_utils import (
+    _reduce_3_tuple,
    check_or_create_options,
    driver,
    get_binding_version,

@@ -78,6 +79,8 @@ def launch(stream, config, kernel, *kernel_args):
    if _use_ex:
        drv_cfg = _to_native_launch_config(config)
        drv_cfg.hStream = stream.handle
+        if config.cooperative_launch:
+            _check_cooperative_launch(kernel, config, stream)
        handle_return(driver.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0))
    else:
        # TODO: check if config has any unsupported attrs

@@ -86,3 +89,16 @@ def launch(stream, config, kernel, *kernel_args):
                int(kernel._handle), *config.grid, *config.block, config.shmem_size, stream.handle, args_ptr, 0
            )
        )
+
+
+def _check_cooperative_launch(kernel: Kernel, config: LaunchConfig, stream: Stream):
+    dev = stream.device
+    num_sm = dev.properties.multiprocessor_count
+    max_grid_size = (
+        kernel.occupancy.max_active_blocks_per_multiprocessor(_reduce_3_tuple(config.block), config.shmem_size) * num_sm
+    )
+    if _reduce_3_tuple(config.grid) > max_grid_size:
+        # For now let's try not to be smart and adjust the grid size behind users' back.
+        # We explicitly ask users to adjust.
+        x, y, z = config.grid
+        raise ValueError(f"The specified grid size ({x} * {y} * {z}) exceeds the limit ({max_grid_size})")
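
The guard mirrors the driver's requirement that every block of a cooperative launch be resident at once: the total grid size may not exceed max_active_blocks_per_multiprocessor times the SM count. The same bound can be computed up front with the occupancy API; a sketch, continuing from the dev and ker objects in the example near the top (block size 128 and zero shared memory are arbitrary assumptions):

    block = 128
    shmem_size = 0
    num_sm = dev.properties.multiprocessor_count
    max_coop_blocks = ker.occupancy.max_active_blocks_per_multiprocessor(block, shmem_size) * num_sm

    # Stay within the bound; launch() raises ValueError rather than silently shrinking the grid.
    config = LaunchConfig(grid=max_coop_blocks, block=block, cooperative_launch=True)

Note the design choice called out in the inline comment: instead of adjusting an oversized grid behind the user's back, the check asks the user to pick a smaller grid explicitly.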

cuda_core/cuda/core/experimental/_utils/cuda_utils.py

Lines changed: 4 additions & 0 deletions

@@ -48,6 +48,10 @@ def cast_to_3_tuple(label, cfg):
    return cfg + (1,) * (3 - len(cfg))


+def _reduce_3_tuple(t: tuple):
+    return t[0] * t[1] * t[2]
+
+
def _check_driver_error(error):
    if error == driver.CUresult.CUDA_SUCCESS:
        return

cuda_core/docs/source/release/0.3.0-notes.rst

Lines changed: 2 additions & 1 deletion

@@ -22,6 +22,7 @@ New features

- :class:`Kernel` adds :property:`Kernel.num_arguments` and :property:`Kernel.arguments_info` for introspection of kernel arguments. (#612)
- Add pythonic access to kernel occupancy calculation functions via :property:`Kernel.occupancy`. (#648)
+- Support launching cooperative kernels by setting :property:`LaunchConfig.cooperative_launch` to `True`.

New examples
------------

@@ -31,4 +32,4 @@ Fixes and enhancements
----------------------

- An :class:`Event` can now be used to look up its corresponding device and context using the ``.device`` and ``.context`` attributes respectively.
-- The :func:`launch` function's handling of fp16 scalars was incorrect and is fixed
+- The :func:`launch` function's handling of fp16 scalars was incorrect and is fixed.

cuda_core/tests/conftest.py

Lines changed: 4 additions & 0 deletions

@@ -69,3 +69,7 @@ def pop_all_contexts():
    os.environ.get("CUDA_PYTHON_TESTING_WITH_COMPUTE_SANITIZER", "0") == "1",
    reason="The compute-sanitizer is running, and this test causes an API error.",
)
+
+
+# TODO: make the fixture more sophisticated using path finder
+skipif_need_cuda_headers = pytest.mark.skipif(os.environ.get("CUDA_PATH") is None, reason="need CUDA header")

cuda_core/tests/test_device.py

Lines changed: 1 addition & 0 deletions

@@ -191,6 +191,7 @@ def test_compute_capability():
    ("concurrent_managed_access", bool),
    ("compute_preemption_supported", bool),
    ("can_use_host_pointer_for_registered_mem", bool),
+    ("cooperative_launch", bool),
    ("max_shared_memory_per_block_optin", int),
    ("pageable_memory_access_uses_host_page_tables", bool),
    ("direct_managed_mem_access_from_host", bool),

cuda_core/tests/test_event.py

Lines changed: 2 additions & 3 deletions

@@ -7,7 +7,7 @@

import numpy as np
import pytest
-from conftest import skipif_testing_with_compute_sanitizer
+from conftest import skipif_need_cuda_headers, skipif_testing_with_compute_sanitizer

import cuda.core.experimental
from cuda.core.experimental import Device, EventOptions, LaunchConfig, Program, ProgramOptions, launch

@@ -114,9 +114,8 @@ def test_error_timing_recorded():
        event3 - event2


-# TODO: improve this once path finder can find headers
@skipif_testing_with_compute_sanitizer
-@pytest.mark.skipif(os.environ.get("CUDA_PATH") is None, reason="need libcu++ header")
+@skipif_need_cuda_headers  # libcu++
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_error_timing_incomplete():
    device = Device()

cuda_core/tests/test_launcher.py

Lines changed: 46 additions & 0 deletions

@@ -7,6 +7,7 @@

import numpy as np
import pytest
+from conftest import skipif_need_cuda_headers

from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch
from cuda.core.experimental._memory import _DefaultPinnedMemorySource

@@ -152,3 +153,48 @@ def test_launch_scalar_argument(python_type, cpp_type, init_value):

    # Check result
    assert arr[0] == init_value, f"Expected {init_value}, got {arr[0]}"
+
+
+@skipif_need_cuda_headers  # cg
+def test_cooperative_launch():
+    dev = Device()
+    dev.set_current()
+    s = dev.create_stream(options={"nonblocking": True})
+
+    # CUDA kernel templated on type T
+    code = r"""
+    #include <cooperative_groups.h>
+
+    extern "C" __global__ void test_grid_sync() {
+        namespace cg = cooperative_groups;
+        auto grid = cg::this_grid();
+        grid.sync();
+    }
+    """
+
+    # Compile and force instantiation for this type
+    arch = "".join(f"{i}" for i in dev.compute_capability)
+    include_path = str(pathlib.Path(os.environ["CUDA_PATH"]) / pathlib.Path("include"))
+    pro_opts = ProgramOptions(std="c++17", arch=f"sm_{arch}", include_path=include_path)
+    prog = Program(code, code_type="c++", options=pro_opts)
+    ker = prog.compile("cubin").get_kernel("test_grid_sync")
+
+    # # Launch without setting cooperative_launch
+    # # Commented out as this seems to be a sticky error...
+    # config = LaunchConfig(grid=1, block=1)
+    # launch(s, config, ker)
+    # from cuda.core.experimental._utils.cuda_utils import CUDAError
+    # with pytest.raises(CUDAError) as e:
+    #     s.sync()
+    # assert "CUDA_ERROR_LAUNCH_FAILED" in str(e)
+
+    # Crazy grid sizes would not work
+    block = 128
+    config = LaunchConfig(grid=dev.properties.max_grid_dim_x // block + 1, block=block, cooperative_launch=True)
+    with pytest.raises(ValueError):
+        launch(s, config, ker)
+
+    # This works just fine
+    config = LaunchConfig(grid=1, block=1, cooperative_launch=True)
+    launch(s, config, ker)
+    s.sync()
