Skip to content

Accept NumPy arrays in advanced indexing #2128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 102 additions & 63 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,20 +756,28 @@ def _extract_impl(ary, ary_mask, axis=0):
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
)
if not isinstance(ary_mask, dpt.usm_ndarray):
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
if isinstance(ary_mask, dpt.usm_ndarray):
dst_usm_type = dpctl.utils.get_coerced_usm_type(
(ary.usm_type, ary_mask.usm_type)
)
dst_usm_type = dpctl.utils.get_coerced_usm_type(
(ary.usm_type, ary_mask.usm_type)
)
exec_q = dpctl.utils.get_execution_queue(
(ary.sycl_queue, ary_mask.sycl_queue)
)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"arrays have different associated queues. "
"Use `y.to_device(x.device)` to migrate."
exec_q = dpctl.utils.get_execution_queue(
(ary.sycl_queue, ary_mask.sycl_queue)
)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"arrays have different associated queues. "
"Use `y.to_device(x.device)` to migrate."
)
elif isinstance(ary_mask, np.ndarray):
dst_usm_type = ary.usm_type
exec_q = ary.sycl_queue
ary_mask = dpt.asarray(
ary_mask, usm_type=dst_usm_type, sycl_queue=exec_q
)
else:
raise TypeError(
"Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
f"{type(ary_mask)}"
)
ary_nd = ary.ndim
pp = normalize_axis_index(operator.index(axis), ary_nd)
Expand Down Expand Up @@ -837,35 +845,40 @@ def _nonzero_impl(ary):
return res


def _validate_indices(inds, queue_list, usm_type_list):
def _get_indices_queue_usm_type(inds, queue, usm_type):
"""
Utility for validating indices are usm_ndarray of integral dtype or Python
integers. At least one must be an array.
Utility for validating indices are NumPy ndarray or usm_ndarray of integral
dtype or Python integers. At least one must be an array.

For each array, the queue and usm type are appended to `queue_list` and
`usm_type_list`, respectively.
"""
any_usmarray = False
queues = [queue]
usm_types = [usm_type]
any_array = False
for ind in inds:
if isinstance(ind, dpt.usm_ndarray):
any_usmarray = True
if isinstance(ind, (np.ndarray, dpt.usm_ndarray)):
any_array = True
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) "
"type"
)
queue_list.append(ind.sycl_queue)
usm_type_list.append(ind.usm_type)
if isinstance(ind, dpt.usm_ndarray):
queues.append(ind.sycl_queue)
usm_types.append(ind.usm_type)
elif not isinstance(ind, Integral):
raise TypeError(
"all elements of `ind` expected to be usm_ndarrays "
f"or integers, found {type(ind)}"
"all elements of `ind` expected to be usm_ndarrays, "
f"NumPy arrays, or integers, found {type(ind)}"
)
if not any_usmarray:
if not any_array:
raise TypeError(
"at least one element of `inds` expected to be a usm_ndarray"
"at least one element of `inds` expected to be an array"
)
return inds
usm_type = dpctl.utils.get_coerced_usm_type(usm_types)
q = dpctl.utils.get_execution_queue(queues)
return q, usm_type


def _prepare_indices_arrays(inds, q, usm_type):
Expand Down Expand Up @@ -922,18 +935,12 @@ def _take_multi_index(ary, inds, p, mode=0):
raise ValueError(
"Invalid value for mode keyword, only 0 or 1 is supported"
)
queues_ = [
ary.sycl_queue,
]
usm_types_ = [
ary.usm_type,
]
if not isinstance(inds, (list, tuple)):
inds = (inds,)

_validate_indices(inds, queues_, usm_types_)
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
exec_q, res_usm_type = _get_indices_queue_usm_type(
inds, ary.sycl_queue, ary.usm_type
)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"Can not automatically determine where to allocate the "
Expand All @@ -942,8 +949,7 @@ def _take_multi_index(ary, inds, p, mode=0):
"be associated with the same queue."
)

if len(inds) > 1:
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)

ind0 = inds[0]
ary_sh = ary.shape
Expand Down Expand Up @@ -976,21 +982,51 @@ def _place_impl(ary, ary_mask, vals, axis=0):
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
)
if not isinstance(ary_mask, dpt.usm_ndarray):
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
if isinstance(ary_mask, dpt.usm_ndarray):
exec_q = dpctl.utils.get_execution_queue(
(
ary.sycl_queue,
ary_mask.sycl_queue,
)
)
exec_q = dpctl.utils.get_execution_queue(
(
ary.sycl_queue,
ary_mask.sycl_queue,
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
(
ary.usm_type,
ary_mask.usm_type,
)
)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"arrays have different associated queues. "
"Use `y.to_device(x.device)` to migrate."
)
elif isinstance(ary_mask, np.ndarray):
exec_q = ary.sycl_queue
coerced_usm_type = ary.usm_type
ary_mask = dpt.asarray(
ary_mask, usm_type=coerced_usm_type, sycl_queue=exec_q
)
else:
raise TypeError(
"Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
f"{type(ary_mask)}"
)
)
if exec_q is not None:
if not isinstance(vals, dpt.usm_ndarray):
vals = dpt.asarray(vals, dtype=ary.dtype, sycl_queue=exec_q)
vals = dpt.asarray(
vals,
dtype=ary.dtype,
usm_type=coerced_usm_type,
sycl_queue=exec_q,
)
else:
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
(
coerced_usm_type,
vals.usm_type,
)
)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"arrays have different associated queues. "
Expand All @@ -1005,7 +1041,12 @@ def _place_impl(ary, ary_mask, vals, axis=0):
)
mask_nelems = ary_mask.size
cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
cumsum = dpt.empty(mask_nelems, dtype=cumsum_dt, device=ary_mask.device)
cumsum = dpt.empty(
mask_nelems,
dtype=cumsum_dt,
usm_type=coerced_usm_type,
device=ary_mask.device,
)
exec_q = cumsum.sycl_queue
_manager = dpctl.utils.SequentialOrderManager[exec_q]
dep_ev = _manager.submitted_events
Expand Down Expand Up @@ -1048,30 +1089,29 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
raise ValueError(
"Invalid value for mode keyword, only 0 or 1 is supported"
)
if isinstance(vals, dpt.usm_ndarray):
queues_ = [ary.sycl_queue, vals.sycl_queue]
usm_types_ = [ary.usm_type, vals.usm_type]
else:
queues_ = [
ary.sycl_queue,
]
usm_types_ = [
ary.usm_type,
]
if not isinstance(inds, (list, tuple)):
inds = (inds,)

_validate_indices(inds, queues_, usm_types_)
exec_q, coerced_usm_type = _get_indices_queue_usm_type(
inds, ary.sycl_queue, ary.usm_type
)

vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is not None:
if not isinstance(vals, dpt.usm_ndarray):
vals = dpt.asarray(
vals, dtype=ary.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
vals,
dtype=ary.dtype,
usm_type=coerced_usm_type,
sycl_queue=exec_q,
)
else:
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
(
coerced_usm_type,
vals.usm_type,
)
)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"Can not automatically determine where to allocate the "
Expand All @@ -1080,8 +1120,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
"be associated with the same queue."
)

if len(inds) > 1:
inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type)
inds = _prepare_indices_arrays(inds, exec_q, coerced_usm_type)

ind0 = inds[0]
ary_sh = ary.shape
Expand Down
13 changes: 7 additions & 6 deletions dpctl/tensor/_slicing.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numbers
from operator import index
from cpython.buffer cimport PyObject_CheckBuffer
from numpy import ndarray


cdef bint _is_buffer(object o):
Expand Down Expand Up @@ -46,7 +47,7 @@ cdef Py_ssize_t _slice_len(

cdef bint _is_integral(object x) except *:
"""Gives True if x is an integral slice spec"""
if isinstance(x, usm_ndarray):
if isinstance(x, (ndarray, usm_ndarray)):
if x.ndim > 0:
return False
if x.dtype.kind not in "ui":
Expand Down Expand Up @@ -74,7 +75,7 @@ cdef bint _is_integral(object x) except *:

cdef bint _is_boolean(object x) except *:
"""Gives True if x is an integral slice spec"""
if isinstance(x, usm_ndarray):
if isinstance(x, (ndarray, usm_ndarray)):
if x.ndim > 0:
return False
if x.dtype.kind not in "b":
Expand Down Expand Up @@ -185,7 +186,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
raise IndexError(
"Index {0} is out of range for axes 0 with "
"size {1}".format(ind, shape[0]))
elif isinstance(ind, usm_ndarray):
elif isinstance(ind, (ndarray, usm_ndarray)):
return (shape, strides, offset, (ind,), 0)
elif isinstance(ind, tuple):
axes_referenced = 0
Expand Down Expand Up @@ -216,7 +217,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
axes_referenced += 1
if not array_streak_started and array_streak_interrupted:
explicit_index += 1
elif isinstance(i, usm_ndarray):
elif isinstance(i, (ndarray, usm_ndarray)):
if not seen_arrays_yet:
seen_arrays_yet = True
array_streak_started = True
Expand Down Expand Up @@ -302,7 +303,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
array_streak = False
elif _is_integral(ind_i):
if array_streak:
if not isinstance(ind_i, usm_ndarray):
if not isinstance(ind_i, (ndarray, usm_ndarray)):
ind_i = index(ind_i)
# integer will be converted to an array,
# still raise if OOB
Expand Down Expand Up @@ -337,7 +338,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
"Index {0} is out of range for axes "
"{1} with size {2}".format(ind_i, k, shape[k])
)
elif isinstance(ind_i, usm_ndarray):
elif isinstance(ind_i, (ndarray, usm_ndarray)):
if not array_streak:
array_streak = True
if not advanced_start_pos_set:
Expand Down
22 changes: 22 additions & 0 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,28 @@ def test_advanced_slice16():
assert isinstance(y, dpt.usm_ndarray)


def test_integer_indexing_numpy_array():
q = get_queue_or_skip()
ii = np.asarray([1, 2])
x = dpt.arange(10, dtype="i4", sycl_queue=q)
y = x[ii]
assert isinstance(y, dpt.usm_ndarray)
assert y.shape == ii.shape
assert dpt.all(x[1:3] == y)


def test_boolean_indexing_numpy_array():
q = get_queue_or_skip()
ii = np.asarray(
[False, True, True, False, False, False, False, False, False, False]
)
x = dpt.arange(10, dtype="i4", sycl_queue=q)
y = x[ii]
assert isinstance(y, dpt.usm_ndarray)
assert y.shape == (2,)
assert dpt.all(x[1:3] == y)


def test_boolean_indexing_validation():
get_queue_or_skip()
x = dpt.zeros(10, dtype="i4")
Expand Down
Loading