Skip to content

Commit 9323324

Browse files
committed
Multi-device support for array manipulation
1 parent e0b2a64 commit 9323324

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

array_api_strict/_manipulation_functions.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@ def concat(
2525
# Note: Casting rules here are different from the np.concatenate default
2626
# (no for scalars with axis=None, no cross-kind casting)
2727
dtype = result_type(*arrays)
28+
if len({a.device for a in arrays}) > 1:
29+
raise ValueError("concat inputs must all be on the same device")
30+
2831
arrays = tuple(a._array for a in arrays)
29-
return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype))
32+
return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype), device=arrays[0].device)
3033

3134

3235
def expand_dims(x: Array, /, *, axis: int) -> Array:
@@ -35,7 +38,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array:
3538
3639
See its docstring for more information.
3740
"""
38-
return Array._new(np.expand_dims(x._array, axis))
41+
return Array._new(np.expand_dims(x._array, axis), device=x.device)
3942

4043

4144
def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
@@ -44,7 +47,7 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
4447
4548
See its docstring for more information.
4649
"""
47-
return Array._new(np.flip(x._array, axis=axis))
50+
return Array._new(np.flip(x._array, axis=axis), device=x.device)
4851

4952
@requires_api_version('2023.12')
5053
def moveaxis(
@@ -58,7 +61,7 @@ def moveaxis(
5861
5962
See its docstring for more information.
6063
"""
61-
return Array._new(np.moveaxis(x._array, source, destination))
64+
return Array._new(np.moveaxis(x._array, source, destination), device=x.device)
6265

6366
# Note: The function name is different here (see also matrix_transpose).
6467
# Unlike transpose(), the axes argument is required.
@@ -68,7 +71,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
6871
6972
See its docstring for more information.
7073
"""
71-
return Array._new(np.transpose(x._array, axes))
74+
return Array._new(np.transpose(x._array, axes), device=x.device)
7275

7376
@requires_api_version('2023.12')
7477
def repeat(
@@ -94,7 +97,7 @@ def repeat(
9497
else:
9598
raise TypeError("repeats must be an int or array")
9699

97-
return Array._new(np.repeat(x._array, repeats, axis=axis))
100+
return Array._new(np.repeat(x._array, repeats, axis=axis), device=x.device)
98101

99102
# Note: the optional argument is called 'shape', not 'newshape'
100103
def reshape(x: Array,
@@ -117,7 +120,7 @@ def reshape(x: Array,
117120
if copy is False and not np.shares_memory(data, reshaped):
118121
raise AttributeError("Incompatible shape for in-place modification.")
119122

120-
return Array._new(reshaped)
123+
return Array._new(reshaped, device=x.device)
121124

122125

123126
def roll(
@@ -132,7 +135,7 @@ def roll(
132135
133136
See its docstring for more information.
134137
"""
135-
return Array._new(np.roll(x._array, shift, axis=axis))
138+
return Array._new(np.roll(x._array, shift, axis=axis), device=x.device)
136139

137140

138141
def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
@@ -141,7 +144,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
141144
142145
See its docstring for more information.
143146
"""
144-
return Array._new(np.squeeze(x._array, axis=axis))
147+
return Array._new(np.squeeze(x._array, axis=axis), device=x.device)
145148

146149

147150
def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
@@ -152,8 +155,10 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
152155
"""
153156
# Call result type here just to raise on disallowed type combinations
154157
result_type(*arrays)
158+
if len({a.device for a in arrays}) > 1:
159+
raise ValueError("concat inputs must all be on the same device")
155160
arrays = tuple(a._array for a in arrays)
156-
return Array._new(np.stack(arrays, axis=axis))
161+
return Array._new(np.stack(arrays, axis=axis), device=arrays[0].device)
157162

158163

159164
@requires_api_version('2023.12')
@@ -166,7 +171,7 @@ def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array:
166171
# Note: NumPy allows repetitions to be an int or array
167172
if not isinstance(repetitions, tuple):
168173
raise TypeError("repetitions must be a tuple")
169-
return Array._new(np.tile(x._array, repetitions))
174+
return Array._new(np.tile(x._array, repetitions), device=x.device)
170175

171176
# Note: this function is new
172177
@requires_api_version('2023.12')

0 commit comments

Comments
 (0)