-
Notifications
You must be signed in to change notification settings - Fork 30
Accept NumPy arrays in advanced indexing #2128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
View rendered docs @ https://intelpython.github.io/dpctl/pulls/2128/index.html |
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_79 ran successfully. |
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_81 ran successfully. |
exec_q = dpctl.utils.get_execution_queue( | ||
(ary.sycl_queue, ary_mask.sycl_queue) | ||
) | ||
if exec_q is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found a suspicious behavior:
q1 = dpctl.SyclQueue(property="in_order")
q2 = dpctl.SyclQueue(property="enable_profiling")
a = dpt.ones(10, sycl_queue=q1)
mask = dpt.zeros(10, dtype='?', sycl_queue=q2)
# raised error as expected:
a[mask]
---------------------------------------------------------------------------
ExecutionPlacementError Traceback (most recent call last)
Cell In[30], line 1
----> 1 a[mask]
File dpctl/tensor/_usmarray.pyx:999, in dpctl.tensor._usmarray.usm_ndarray.__getitem__()
File /localdisk/work/antonvol/soft/miniforge3/envs/dpnp_dev/lib/python3.12/site-packages/dpctl/tensor/_copy_utils.py:767, in _extract_impl(ary, ary_mask, axis)
763 exec_q = dpctl.utils.get_execution_queue(
764 (ary.sycl_queue, ary_mask.sycl_queue)
765 )
766 if exec_q is None:
--> 767 raise dpctl.utils.ExecutionPlacementError(
768 "arrays have different associated queues. "
769 "Use `y.to_device(x.device)` to migrate."
770 )
771 elif isinstance(ary_mask, np.ndarray):
772 dst_usm_type = ary.usm_type
ExecutionPlacementError: arrays have different associated queues. Use `y.to_device(x.device)` to migrate.
# no error for 0-d array:
mask = dpt.asarray(1, dtype='?', sycl_queue=q2)
a[mask]
# Out: usm_ndarray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
There will be no _extract_impl
called, but seems the check is missing in another implementation branch inside __getitem__
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is actually an expected behavior: when a zero-dim array is passed as mask (or index), it is brought to the host as a scalar. This especially necessary in the case of boolean masks
see: #1136
|
||
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) | ||
exec_q = dpctl.utils.get_execution_queue(queues_) | ||
if exec_q is not None: | ||
if not isinstance(vals, dpt.usm_ndarray): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there missing coercion of usm_type for an input value?
else:
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
vals_usm_type = dpctl.utils.get_coerced_usm_type((vals_usm_type, vals.usm_type))
y = x[ii] | ||
assert isinstance(y, dpt.usm_ndarray) | ||
assert y.shape == ii.shape | ||
assert dpt.all(dpt.asarray(ii, sycl_queue=q) == y) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be better to use
assert dpt.all(dpt.asarray(ii, sycl_queue=q) == y) | |
assert dpt.all(x[1:3] == y) |
since currently it works only because there is dpt.arange(10, ...)
used to allocate data
This PR proposes allowing NumPy arrays when indexing
usm_ndarray
This can be useful where attempting to use
usm_ndarray
in a library which generates indices with NumPy, and which expects to only handle devices implicitly through tensor operations (i.e., a library which can take an arbitrary tensor input, but generates indices withnumpy.random
and expects these indices to work on the input tensor)Both boolean and integer advanced indices can accept NumPy arrays with these changes
Closes #2053