5
5
import os
6
6
import pathlib
7
7
8
+ import cupy as cp
8
9
import numpy as np
9
10
import pytest
10
11
from conftest import skipif_need_cuda_headers
@@ -219,6 +220,8 @@ def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_reso
219
220
dev = Device ()
220
221
dev .set_current ()
221
222
stream = dev .create_stream ()
223
+ # tell CuPy to use our stream as the current stream:
224
+ cp .cuda .ExternalStream (int (stream .handle )).use ()
222
225
223
226
# Kernel that operates on memory
224
227
code = """
@@ -258,18 +261,13 @@ def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_reso
258
261
# For pinned memory, use numpy
259
262
array = np .from_dlpack (buffer ).view (dtype = dtype )
260
263
else :
261
- # For device memory, use cupy
262
- import cupy as cp
263
-
264
264
array = cp .from_dlpack (buffer ).view (dtype = dtype )
265
265
266
266
# Initialize data with random values
267
267
if mr .is_host_accessible :
268
268
rng = np .random .default_rng ()
269
269
array [:] = rng .random (size , dtype = dtype )
270
270
else :
271
- import cupy as cp
272
-
273
271
rng = cp .random .default_rng ()
274
272
array [:] = rng .random (size , dtype = dtype )
275
273
@@ -288,16 +286,13 @@ def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_reso
288
286
stream .sync ()
289
287
290
288
# Verify kernel operations
291
- if mr .is_host_accessible :
292
- assert np .allclose (array , original * 3.0 ), f"{ memory_resource_class .__name__ } operation failed"
293
- else :
294
- import cupy as cp
295
-
296
- assert cp .allclose (array , original * 3.0 ), f"{ memory_resource_class .__name__ } operation failed"
289
+ assert cp .allclose (array , original * 3.0 ), f"{ memory_resource_class .__name__ } operation failed"
297
290
298
291
# Clean up
299
292
buffer .close (stream )
300
293
stream .close ()
301
294
295
+ cp .cuda .Stream .null .use () # reset CuPy's current stream to the null stream
296
+
302
297
# Verify buffer is properly closed
303
298
assert buffer .handle == 0 , f"{ memory_resource_class .__name__ } buffer should be closed"
0 commit comments