|
5 | 5 | import numpy as np
|
6 | 6 | import pytest
|
7 | 7 |
|
8 |
| -from .. import ones, arange, reshape, asarray, result_type, all, equal |
| 8 | +from .. import ones, arange, reshape, asarray, result_type, all, equal, stack |
9 | 9 | from .._array_object import Array, CPU_DEVICE, Device
|
10 | 10 | from .._dtypes import (
|
11 | 11 | _all_dtypes,
|
@@ -101,41 +101,65 @@ def test_validate_index():
|
101 | 101 | assert_raises(IndexError, lambda: a[idx])
|
102 | 102 |
|
103 | 103 |
|
104 |
| -def test_indexing_arrays(): |
| 104 | +@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) |
| 105 | +def test_indexing_arrays(device): |
105 | 106 | # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
|
| 107 | + device = None if device is None else Device(device) |
106 | 108 |
|
107 | 109 | # 1D array
|
108 |
| - a = arange(5) |
109 |
| - idx = asarray([1, 0, 1, 2, -1]) |
| 110 | + a = arange(5, device=device) |
| 111 | + idx = asarray([1, 0, 1, 2, -1], device=device) |
110 | 112 | a_idx = a[idx]
|
111 | 113 |
|
112 |
| - a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])]) |
| 114 | + a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])]) |
113 | 115 | assert all(a_idx == a_idx_loop)
|
| 116 | + assert a_idx.shape == idx.shape |
| 117 | + assert a.device == idx.device == a_idx.device |
114 | 118 |
|
115 | 119 | # setitem with arrays is not allowed
|
116 | 120 | with assert_raises(IndexError):
|
117 | 121 | a[idx] = 42
|
118 | 122 |
|
119 | 123 | # mixed array and integer indexing
|
120 |
| - a = reshape(arange(3*4), (3, 4)) |
121 |
| - idx = asarray([1, 0, 1, 2, -1]) |
| 124 | + a = reshape(arange(3*4, device=device), (3, 4)) |
| 125 | + idx = asarray([1, 0, 1, 2, -1], device=device) |
122 | 126 | a_idx = a[idx, 1]
|
123 |
| - |
124 |
| - a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])]) |
| 127 | + a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])]) |
125 | 128 | assert all(a_idx == a_idx_loop)
|
| 129 | + assert a_idx.shape == idx.shape |
| 130 | + assert a.device == idx.device == a_idx.device |
126 | 131 |
|
127 | 132 | # index with two arrays
|
128 | 133 | a_idx = a[idx, idx]
|
129 |
| - a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])]) |
| 134 | + a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])]) |
130 | 135 | assert all(a_idx == a_idx_loop)
|
| 136 | + assert a_idx.shape == a_idx.shape |
| 137 | + assert a.device == idx.device == a_idx.device |
131 | 138 |
|
132 | 139 | # setitem with arrays is not allowed
|
133 | 140 | with assert_raises(IndexError):
|
134 | 141 | a[idx, idx] = 42
|
135 | 142 |
|
136 | 143 | # smoke test indexing with ndim > 1 arrays
|
137 | 144 | idx = idx[..., None]
|
138 |
| - a[idx, idx] |
| 145 | + a_idx = a[idx, idx] |
| 146 | + assert a.device == idx.device == a_idx.device |
| 147 | + |
| 148 | + |
| 149 | +def test_indexing_arrays_different_devices(): |
| 150 | + # Ensure indexing via array on different device errors |
| 151 | + device1 = Device("CPU_DEVICE") |
| 152 | + device2 = Device("device1") |
| 153 | + |
| 154 | + a = arange(5, device=device1) |
| 155 | + idx1 = asarray([1, 0, 1, 2, -1], device=device2) |
| 156 | + idx2 = asarray([1, 0, 1, 2, -1], device=device1) |
| 157 | + |
| 158 | + with pytest.raises(ValueError, match="Array indexing is only allowed when"): |
| 159 | + a[idx1] |
| 160 | + |
| 161 | + with pytest.raises(ValueError, match="Array indexing is only allowed when"): |
| 162 | + a[idx1, idx2] |
139 | 163 |
|
140 | 164 |
|
141 | 165 | def test_promoted_scalar_inherits_device():
|
|
0 commit comments