
Commit fd8e07b

Ensure correct handling of buffers allocated with LegacyPinnedMemoryResource.allocate as kernel parameters (#717)
* Add memory ops example
* Fix handling of buffer with int handle
* pre-commit fixes
* Simplify pinned memory example
* Copy pinned memory tests to test_launcher.py
* Remove dlpack assertions and address other review comments

Co-authored-by: Ashwin Srinath <[email protected]>
1 parent 24fde17 commit fd8e07b
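
The scenario this commit fixes is passing a buffer allocated from LegacyPinnedMemoryResource directly to launch as a kernel parameter. A minimal sketch of the now-working pattern, condensed from the memory_ops.py example added in this commit (it assumes a current Device, a stream, and a compiled kernel taking a float* and a size_t, as in that example):

    import numpy as np
    from cuda.core.experimental import LaunchConfig, LegacyPinnedMemoryResource, launch

    # Allocate host-pinned memory; the resulting Buffer's .handle is a plain Python int.
    pinned_mr = LegacyPinnedMemoryResource()
    pinned_buffer = pinned_mr.allocate(1024 * np.float32().itemsize, stream=stream)
    pinned_array = np.from_dlpack(pinned_buffer).view(dtype=np.float32)
    pinned_array[:] = 1.0

    # Passing the Buffer itself as a kernel parameter is what this commit ensures works.
    config = LaunchConfig(grid=4, block=256)
    launch(stream, config, kernel, pinned_buffer, np.uint64(pinned_array.size))
    stream.sync()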

3 files changed: +254 −2 lines changed


cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx

Lines changed: 7 additions & 1 deletion
@@ -212,7 +212,13 @@ cdef class ParamHolder:
         for i, arg in enumerate(kernel_args):
             if isinstance(arg, Buffer):
                 # we need the address of where the actual buffer address is stored
-                self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
+                if isinstance(arg.handle, int):
+                    # see note below on handling int arguments
+                    prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
+                    continue
+                else:
+                    # it's a CUdeviceptr:
+                    self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
                 continue
             elif isinstance(arg, int):
                 # Here's the dilemma: We want to have a fast path to pass in Python
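
In short: a device Buffer's .handle is a driver CUdeviceptr object whose raw address is read via getPtr(), while a Buffer returned by LegacyPinnedMemoryResource.allocate carries its address as a plain Python int, so the new branch routes it through the same prepare_arg[intptr_t] path used for integer kernel arguments. A rough illustration of the two handle kinds (a sketch based on this diff, not an exhaustive description of the Buffer API):

    device_buffer = dev.memory_resource.allocate(nbytes, stream=stream)
    # device_buffer.handle -> CUdeviceptr-like object; the handler calls .getPtr() on it

    pinned_buffer = LegacyPinnedMemoryResource().allocate(nbytes, stream=stream)
    # pinned_buffer.handle -> int; now forwarded via prepare_arg[intptr_t](...)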

cuda_core/examples/memory_ops.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0

# ################################################################################
#
# This demo illustrates:
#
# 1. How to use different memory resources to allocate and manage memory
# 2. How to copy data between different memory types
# 3. How to use DLPack to interoperate with other libraries
#
# ################################################################################

import sys

import cupy as cp
import numpy as np

from cuda.core.experimental import (
    Device,
    LaunchConfig,
    LegacyPinnedMemoryResource,
    Program,
    ProgramOptions,
    launch,
)

if np.__version__ < "2.1.0":
    print("This example requires NumPy 2.1.0 or later", file=sys.stderr)
    sys.exit(0)

# Kernel for memory operations
code = """
extern "C"
__global__ void memory_ops(float* device_data,
                           float* pinned_data,
                           size_t N) {
    const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid < N) {
        // Access device memory
        device_data[tid] = device_data[tid] + 1.0f;

        // Access pinned memory (zero-copy from GPU)
        pinned_data[tid] = pinned_data[tid] * 3.0f;
    }
}
"""

dev = Device()
dev.set_current()
stream = dev.create_stream()
# tell CuPy to use our stream as the current stream:
cp.cuda.ExternalStream(int(stream.handle)).use()

# Compile kernel
arch = "".join(f"{i}" for i in dev.compute_capability)
program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
prog = Program(code, code_type="c++", options=program_options)
mod = prog.compile("cubin")
kernel = mod.get_kernel("memory_ops")

# Create different memory resources
device_mr = dev.memory_resource
pinned_mr = LegacyPinnedMemoryResource()

# Allocate different types of memory
size = 1024
dtype = cp.float32
element_size = dtype().itemsize
total_size = size * element_size

# 1. Device Memory (GPU-only)
device_buffer = device_mr.allocate(total_size, stream=stream)
device_array = cp.from_dlpack(device_buffer).view(dtype=dtype)

# 2. Pinned Memory (CPU memory, GPU accessible)
pinned_buffer = pinned_mr.allocate(total_size, stream=stream)
pinned_array = np.from_dlpack(pinned_buffer).view(dtype=dtype)

# Initialize data
rng = cp.random.default_rng()
device_array[:] = rng.random(size, dtype=dtype)
pinned_array[:] = rng.random(size, dtype=dtype).get()

# Store original values for verification
device_original = device_array.copy()
pinned_original = pinned_array.copy()

# Sync before kernel launch
stream.sync()

# Launch kernel
block = 256
grid = (size + block - 1) // block
config = LaunchConfig(grid=grid, block=block)

launch(stream, config, kernel, device_buffer, pinned_buffer, cp.uint64(size))
stream.sync()

# Verify kernel operations
assert cp.allclose(device_array, device_original + 1.0), "Device memory operation failed"
assert cp.allclose(pinned_array, pinned_original * 3.0), "Pinned memory operation failed"

# Copy data between different memory types
print("\nCopying data between memory types...")

# Copy from device to pinned memory
device_buffer.copy_to(pinned_buffer, stream=stream)
stream.sync()

# Verify the copy operation
assert cp.allclose(pinned_array, device_array), "Device to pinned copy failed"

# Create a new device buffer and copy from pinned
new_device_buffer = device_mr.allocate(total_size, stream=stream)
new_device_array = cp.from_dlpack(new_device_buffer).view(dtype=dtype)

pinned_buffer.copy_to(new_device_buffer, stream=stream)
stream.sync()

# Verify the copy operation
assert cp.allclose(new_device_array, pinned_array), "Pinned to device copy failed"

# Clean up
device_buffer.close(stream)
pinned_buffer.close(stream)
new_device_buffer.close(stream)
stream.close()
cp.cuda.Stream.null.use()  # reset CuPy's current stream to the null stream

# Verify buffers are properly closed
assert device_buffer.handle == 0, "Device buffer should be closed"
assert pinned_buffer.handle == 0, "Pinned buffer should be closed"
assert new_device_buffer.handle == 0, "New device buffer should be closed"

print("Memory management example completed!")

cuda_core/tests/test_launcher.py

Lines changed: 110 additions & 1 deletion
@@ -5,11 +5,21 @@
 import os
 import pathlib
 
+import cupy as cp
 import numpy as np
 import pytest
 from conftest import skipif_need_cuda_headers
 
-from cuda.core.experimental import Device, LaunchConfig, LegacyPinnedMemoryResource, Program, ProgramOptions, launch
+from cuda.core.experimental import (
+    Device,
+    DeviceMemoryResource,
+    LaunchConfig,
+    LegacyPinnedMemoryResource,
+    Program,
+    ProgramOptions,
+    launch,
+)
+from cuda.core.experimental._memory import _SynchronousMemoryResource
 
 
 def test_launch_config_init(init_cuda):
@@ -197,3 +207,102 @@ def test_cooperative_launch():
     config = LaunchConfig(grid=1, block=1, cooperative_launch=True)
     launch(s, config, ker)
     s.sync()
+
+
+@pytest.mark.parametrize(
+    "memory_resource_class",
+    [
+        "device_memory_resource",  # kludgy, but can go away after #726 is resolved
+        pytest.param(
+            LegacyPinnedMemoryResource,
+            marks=pytest.mark.skipif(
+                tuple(int(i) for i in np.__version__.split(".")[:3]) < (2, 2, 5),
+                reason="need numpy 2.2.5+, numpy GH #28632",
+            ),
+        ),
+    ],
+)
+def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_resource_class):
+    """Test that kernels can access memory allocated by memory resources."""
+    dev = Device()
+    dev.set_current()
+    stream = dev.create_stream()
+    # tell CuPy to use our stream as the current stream:
+    cp.cuda.ExternalStream(int(stream.handle)).use()
+
+    # Kernel that operates on memory
+    code = """
+    extern "C"
+    __global__ void memory_ops(float* data, size_t N) {
+        const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
+        if (tid < N) {
+            // Access memory (device or pinned)
+            data[tid] = data[tid] * 3.0f;
+        }
+    }
+    """
+
+    # Compile kernel
+    arch = "".join(f"{i}" for i in dev.compute_capability)
+    program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
+    prog = Program(code, code_type="c++", options=program_options)
+    mod = prog.compile("cubin")
+    kernel = mod.get_kernel("memory_ops")
+
+    # Create memory resource
+    if memory_resource_class == "device_memory_resource":
+        if dev.properties.memory_pools_supported:
+            mr = DeviceMemoryResource(dev.device_id)
+        else:
+            mr = _SynchronousMemoryResource(dev.device_id)
+    else:
+        mr = memory_resource_class()
+
+    # Allocate memory
+    size = 1024
+    dtype = np.float32
+    element_size = dtype().itemsize
+    total_size = size * element_size
+
+    buffer = mr.allocate(total_size, stream=stream)
+
+    # Create array view based on memory type
+    if mr.is_host_accessible:
+        # For pinned memory, use numpy
+        array = np.from_dlpack(buffer).view(dtype=dtype)
+    else:
+        array = cp.from_dlpack(buffer).view(dtype=dtype)
+
+    # Initialize data with random values
+    if mr.is_host_accessible:
+        rng = np.random.default_rng()
+        array[:] = rng.random(size, dtype=dtype)
+    else:
+        rng = cp.random.default_rng()
+        array[:] = rng.random(size, dtype=dtype)
+
+    # Store original values for verification
+    original = array.copy()
+
+    # Sync before kernel launch
+    stream.sync()
+
+    # Launch kernel
+    block = 256
+    grid = (size + block - 1) // block
+    config = LaunchConfig(grid=grid, block=block)
+
+    launch(stream, config, kernel, buffer, np.uint64(size))
+    stream.sync()
+
+    # Verify kernel operations
+    assert cp.allclose(array, original * 3.0), f"{memory_resource_class.__name__} operation failed"
+
+    # Clean up
+    buffer.close(stream)
+    stream.close()
+
+    cp.cuda.Stream.null.use()  # reset CuPy's current stream to the null stream
+
+    # Verify buffer is properly closed
+    assert buffer.handle == 0, f"{memory_resource_class.__name__} buffer should be closed"
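
To run just the new test (assumed pytest invocation; the pinned-memory case is skipped unless NumPy is 2.2.5 or newer, per the skipif mark above):

    pytest cuda_core/tests/test_launcher.py -k test_launch_with_buffers_allocated_by_memory_resource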
