Skip to content

Commit ea5deb1

Browse files
authored
BUG: fix tuple array indexing
reviewed at #139
1 parent 7fa1667 commit ea5deb1

File tree

2 files changed

+53
-13
lines changed

2 files changed

+53
-13
lines changed

array_api_strict/_array_object.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -716,8 +716,24 @@ def __getitem__(
716716
# Note: Only indices required by the spec are allowed. See the
717717
# docstring of _validate_index
718718
self._validate_index(key, op="getitem")
719-
# Indexing self._array with array_api_strict arrays can be erroneous
720-
np_key = key._array if isinstance(key, Array) else key
719+
if isinstance(key, Array):
720+
key = (key,)
721+
np_key = key
722+
devices = {self.device}
723+
if isinstance(key, tuple):
724+
devices.update(
725+
[subkey.device for subkey in key if hasattr(subkey, "device")]
726+
)
727+
if len(devices) > 1:
728+
raise ValueError(
729+
"Array indexing is only allowed when array to be indexed and all "
730+
"indexing arrays are on the same device."
731+
)
732+
# Indexing self._array with array_api_strict arrays can be erroneous
733+
# e.g., when using non-default device
734+
np_key = tuple(
735+
subkey._array if isinstance(subkey, Array) else subkey for subkey in key
736+
)
721737
res = self._array.__getitem__(np_key)
722738
return self._new(res, device=self.device)
723739

array_api_strict/tests/test_array_object.py

+35-11
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pytest
77

8-
from .. import ones, arange, reshape, asarray, result_type, all, equal
8+
from .. import ones, arange, reshape, asarray, result_type, all, equal, stack
99
from .._array_object import Array, CPU_DEVICE, Device
1010
from .._dtypes import (
1111
_all_dtypes,
@@ -101,41 +101,65 @@ def test_validate_index():
101101
assert_raises(IndexError, lambda: a[idx])
102102

103103

104-
def test_indexing_arrays():
104+
@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"])
105+
def test_indexing_arrays(device):
105106
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
107+
device = None if device is None else Device(device)
106108

107109
# 1D array
108-
a = arange(5)
109-
idx = asarray([1, 0, 1, 2, -1])
110+
a = arange(5, device=device)
111+
idx = asarray([1, 0, 1, 2, -1], device=device)
110112
a_idx = a[idx]
111113

112-
a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
114+
a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])])
113115
assert all(a_idx == a_idx_loop)
116+
assert a_idx.shape == idx.shape
117+
assert a.device == idx.device == a_idx.device
114118

115119
# setitem with arrays is not allowed
116120
with assert_raises(IndexError):
117121
a[idx] = 42
118122

119123
# mixed array and integer indexing
120-
a = reshape(arange(3*4), (3, 4))
121-
idx = asarray([1, 0, 1, 2, -1])
124+
a = reshape(arange(3*4, device=device), (3, 4))
125+
idx = asarray([1, 0, 1, 2, -1], device=device)
122126
a_idx = a[idx, 1]
123-
124-
a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
127+
a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])])
125128
assert all(a_idx == a_idx_loop)
129+
assert a_idx.shape == idx.shape
130+
assert a.device == idx.device == a_idx.device
126131

127132
# index with two arrays
128133
a_idx = a[idx, idx]
129-
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
134+
a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])])
130135
assert all(a_idx == a_idx_loop)
136+
assert a_idx.shape == a_idx.shape
137+
assert a.device == idx.device == a_idx.device
131138

132139
# setitem with arrays is not allowed
133140
with assert_raises(IndexError):
134141
a[idx, idx] = 42
135142

136143
# smoke test indexing with ndim > 1 arrays
137144
idx = idx[..., None]
138-
a[idx, idx]
145+
a_idx = a[idx, idx]
146+
assert a.device == idx.device == a_idx.device
147+
148+
149+
def test_indexing_arrays_different_devices():
150+
# Ensure indexing via array on different device errors
151+
device1 = Device("CPU_DEVICE")
152+
device2 = Device("device1")
153+
154+
a = arange(5, device=device1)
155+
idx1 = asarray([1, 0, 1, 2, -1], device=device2)
156+
idx2 = asarray([1, 0, 1, 2, -1], device=device1)
157+
158+
with pytest.raises(ValueError, match="Array indexing is only allowed when"):
159+
a[idx1]
160+
161+
with pytest.raises(ValueError, match="Array indexing is only allowed when"):
162+
a[idx1, idx2]
139163

140164

141165
def test_promoted_scalar_inherits_device():

0 commit comments

Comments
 (0)