-
Notifications
You must be signed in to change notification settings - Fork 33
ENH: Simplify CuPy asarray
and to_device
#314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -775,42 +775,28 @@ def _cupy_to_device( | |
/, | ||
stream: int | Any | None = None, | ||
) -> _CupyArray: | ||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||
from cupy.cuda import Device as _Device # pyright: ignore | ||
from cupy.cuda import stream as stream_module # pyright: ignore | ||
from cupy_backends.cuda.api import runtime # pyright: ignore | ||
import cupy as cp | ||
|
||
if device == x.device: | ||
return x | ||
elif device == "cpu": | ||
if device == "cpu": | ||
# allowing us to use `to_device(x, "cpu")` | ||
# is useful for portable test swapping between | ||
# host and device backends | ||
return x.get() | ||
elif not isinstance(device, _Device): | ||
raise ValueError(f"Unsupported device {device!r}") | ||
else: | ||
# see cupy/cupy#5985 for the reason how we handle device/stream here | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue fixed in 2021 |
||
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] | ||
prev_stream = None | ||
if stream is not None: | ||
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore | ||
# stream can be an int as specified in __dlpack__, or a CuPy stream | ||
if isinstance(stream, int): | ||
stream = cp.cuda.ExternalStream(stream) # pyright: ignore | ||
elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType] | ||
pass | ||
else: | ||
raise ValueError("the input stream is not recognized") | ||
stream.use() # pyright: ignore[reportUnknownMemberType] | ||
try: | ||
runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType] | ||
arr = x.copy() | ||
finally: | ||
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] | ||
if stream is not None: | ||
prev_stream.use() | ||
return arr | ||
if not isinstance(device, cp.cuda.Device): | ||
raise TypeError(f"Unsupported device type {device!r}") | ||
|
||
if stream is None: | ||
with device: | ||
return cp.asarray(x) | ||
|
||
# stream can be an int as specified in __dlpack__, or a CuPy stream | ||
if isinstance(stream, int): | ||
stream = cp.cuda.ExternalStream(stream) | ||
elif not isinstance(stream, cp.cuda.Stream): | ||
raise TypeError(f"Unsupported stream type {stream!r}") | ||
|
||
with device, stream: | ||
return cp.asarray(x) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have test coverage for this part? TBH I'm ready to believe it's totally correct but would prefer to have some tests to ensure that it really is. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. None of this is covered by array-api-tests data-apis/array-api-tests#302.
The previous design cites an upstream issue resolved in 2021.
Fully agree, and the place for it is data-apis/array-api-tests#302 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, testing it in array-api-tests would be great. Two things to consider though:
Re: how old is a version we care? Might want some input from CuPy devs, too. Off the cuff, I'd say we definitely care about 13.4; previous 13.x versions with |
||
|
||
|
||
def _torch_to_device( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,8 +64,6 @@ | |
finfo = get_xp(cp)(_aliases.finfo) | ||
iinfo = get_xp(cp)(_aliases.iinfo) | ||
|
||
_copy_default = object() | ||
|
||
|
||
# asarray also adds the copy keyword, which is not present in numpy 1.0. | ||
def asarray( | ||
|
@@ -79,7 +77,7 @@ def asarray( | |
*, | ||
dtype: Optional[DType] = None, | ||
device: Optional[Device] = None, | ||
copy: Optional[bool] = _copy_default, | ||
copy: Optional[bool] = None, | ||
**kwargs, | ||
) -> Array: | ||
""" | ||
|
@@ -89,25 +87,13 @@ def asarray( | |
specification for more details. | ||
""" | ||
with cp.cuda.Device(device): | ||
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments | ||
# in asarray in numpy/_aliases.py. | ||
if copy is not _copy_default: | ||
# A future version of CuPy will change the meaning of copy=False | ||
# to mean no-copy. We don't know for certain what version it will | ||
# be yet, so to avoid breaking that version, we use a different | ||
# default value for copy so asarray(obj) with no copy kwarg will | ||
# always do the copy-if-needed behavior. | ||
|
||
# This will still need to be updated to remove the | ||
# NotImplementedError for copy=False, but at least this won't | ||
# break the default or existing behavior. | ||
if copy is None: | ||
copy = False | ||
elif copy is False: | ||
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy") | ||
kwargs['copy'] = copy | ||
|
||
return cp.array(obj, dtype=dtype, **kwargs) | ||
if copy is None: | ||
return cp.asarray(obj, dtype=dtype, **kwargs) | ||
else: | ||
res = cp.array(obj, dtype=dtype, copy=copy, **kwargs) | ||
if not copy and res is not obj: | ||
raise ValueError("Unable to avoid copy while creating an array as requested") | ||
return res | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part seems to streamline and remove workarounds for presumably older cupy versions.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nope, the latest cupy still has the same exact issues as the old ones.
We need to formulate policy for a minimum supported version of all backends. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The latest cupy you mean 13.4 or 14.0.a1? One thing to bear in mind with CuPy specifically: IIUC they are still working through the numpy 2.0 transition. 13.4 is sort-of compatible (one goal was to make it import with both numpy 1 and numpy 2), 14.x series is supposed to complete the transition. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 13.4.1. |
||
|
||
|
||
def astype( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import pytest | ||
from array_api_compat import device, to_device | ||
|
||
xp = pytest.importorskip("array_api_compat.cupy") | ||
from cupy.cuda import Stream | ||
|
||
|
||
def test_to_device_with_stream(): | ||
devices = xp.__array_namespace_info__().devices() | ||
streams = [ | ||
Stream(), | ||
Stream(non_blocking=True), | ||
Stream(null=True), | ||
Stream(ptds=True), | ||
123, # dlpack stream | ||
] | ||
|
||
a = xp.asarray([1, 2, 3]) | ||
for dev in devices: | ||
for stream in streams: | ||
b = to_device(a, dev, stream=stream) | ||
assert device(b) == dev |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whoa. I'm confused: this returns a numpy array if
x
was a cupy array. This might be intended, but then the return annotation is misleading? Same for the device argument annotation. So maybe add a comment if a numpydoc docstring withParameters
andReturns
sections is an overkill.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
None of this has changed. See also #87.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My point is that gh-87 is still open, and there are no tests, so if we touch this, great, let's take the opportunity to also improve testing.