@@ -25,8 +25,11 @@ def concat(
25
25
# Note: Casting rules here are different from the np.concatenate default
26
26
# (no for scalars with axis=None, no cross-kind casting)
27
27
dtype = result_type (* arrays )
28
+ if len ({a .device for a in arrays }) > 1 :
29
+ raise ValueError ("concat inputs must all be on the same device" )
30
+
28
31
arrays = tuple (a ._array for a in arrays )
29
- return Array ._new (np .concatenate (arrays , axis = axis , dtype = dtype ._np_dtype ))
32
+ return Array ._new (np .concatenate (arrays , axis = axis , dtype = dtype ._np_dtype ), device = arrays [ 0 ]. device )
30
33
31
34
32
35
def expand_dims (x : Array , / , * , axis : int ) -> Array :
@@ -35,7 +38,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array:
35
38
36
39
See its docstring for more information.
37
40
"""
38
- return Array ._new (np .expand_dims (x ._array , axis ))
41
+ return Array ._new (np .expand_dims (x ._array , axis ), device = x . device )
39
42
40
43
41
44
def flip (x : Array , / , * , axis : Optional [Union [int , Tuple [int , ...]]] = None ) -> Array :
@@ -44,7 +47,7 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
44
47
45
48
See its docstring for more information.
46
49
"""
47
- return Array ._new (np .flip (x ._array , axis = axis ))
50
+ return Array ._new (np .flip (x ._array , axis = axis ), device = x . device )
48
51
49
52
@requires_api_version ('2023.12' )
50
53
def moveaxis (
@@ -58,7 +61,7 @@ def moveaxis(
58
61
59
62
See its docstring for more information.
60
63
"""
61
- return Array ._new (np .moveaxis (x ._array , source , destination ))
64
+ return Array ._new (np .moveaxis (x ._array , source , destination ), device = x . device )
62
65
63
66
# Note: The function name is different here (see also matrix_transpose).
64
67
# Unlike transpose(), the axes argument is required.
@@ -68,7 +71,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
68
71
69
72
See its docstring for more information.
70
73
"""
71
- return Array ._new (np .transpose (x ._array , axes ))
74
+ return Array ._new (np .transpose (x ._array , axes ), device = x . device )
72
75
73
76
@requires_api_version ('2023.12' )
74
77
def repeat (
@@ -94,7 +97,7 @@ def repeat(
94
97
else :
95
98
raise TypeError ("repeats must be an int or array" )
96
99
97
- return Array ._new (np .repeat (x ._array , repeats , axis = axis ))
100
+ return Array ._new (np .repeat (x ._array , repeats , axis = axis ), device = x . device )
98
101
99
102
# Note: the optional argument is called 'shape', not 'newshape'
100
103
def reshape (x : Array ,
@@ -117,7 +120,7 @@ def reshape(x: Array,
117
120
if copy is False and not np .shares_memory (data , reshaped ):
118
121
raise AttributeError ("Incompatible shape for in-place modification." )
119
122
120
- return Array ._new (reshaped )
123
+ return Array ._new (reshaped , device = x . device )
121
124
122
125
123
126
def roll (
@@ -132,7 +135,7 @@ def roll(
132
135
133
136
See its docstring for more information.
134
137
"""
135
- return Array ._new (np .roll (x ._array , shift , axis = axis ))
138
+ return Array ._new (np .roll (x ._array , shift , axis = axis ), device = x . device )
136
139
137
140
138
141
def squeeze (x : Array , / , axis : Union [int , Tuple [int , ...]]) -> Array :
@@ -141,7 +144,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
141
144
142
145
See its docstring for more information.
143
146
"""
144
- return Array ._new (np .squeeze (x ._array , axis = axis ))
147
+ return Array ._new (np .squeeze (x ._array , axis = axis ), device = x . device )
145
148
146
149
147
150
def stack (arrays : Union [Tuple [Array , ...], List [Array ]], / , * , axis : int = 0 ) -> Array :
@@ -152,8 +155,10 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
152
155
"""
153
156
# Call result type here just to raise on disallowed type combinations
154
157
result_type (* arrays )
158
+ if len ({a .device for a in arrays }) > 1 :
159
+ raise ValueError ("concat inputs must all be on the same device" )
155
160
arrays = tuple (a ._array for a in arrays )
156
- return Array ._new (np .stack (arrays , axis = axis ))
161
+ return Array ._new (np .stack (arrays , axis = axis ), device = arrays [ 0 ]. device )
157
162
158
163
159
164
@requires_api_version ('2023.12' )
@@ -166,7 +171,7 @@ def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array:
166
171
# Note: NumPy allows repetitions to be an int or array
167
172
if not isinstance (repetitions , tuple ):
168
173
raise TypeError ("repetitions must be a tuple" )
169
- return Array ._new (np .tile (x ._array , repetitions ))
174
+ return Array ._new (np .tile (x ._array , repetitions ), device = x . device )
170
175
171
176
# Note: this function is new
172
177
@requires_api_version ('2023.12' )
0 commit comments