@@ -30,7 +30,10 @@ def matmul(x1: Array, x2: Array, /) -> Array:
30
30
if x1 .dtype not in _numeric_dtypes or x2 .dtype not in _numeric_dtypes :
31
31
raise TypeError ('Only numeric dtypes are allowed in matmul' )
32
32
33
- return Array ._new (np .matmul (x1 ._array , x2 ._array ))
33
+ if x1 .device != x2 .device :
34
+ raise RuntimeError (f"Arrays from two different devices ({ x1 .device } and { x2 .device } ) can not be combined." )
35
+
36
+ return Array ._new (np .matmul (x1 ._array , x2 ._array ), device = x1 .device )
34
37
35
38
# Note: tensordot is the numpy top-level namespace but not in np.linalg
36
39
@@ -41,14 +44,17 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int],
41
44
if x1 .dtype not in _numeric_dtypes or x2 .dtype not in _numeric_dtypes :
42
45
raise TypeError ('Only numeric dtypes are allowed in tensordot' )
43
46
44
- return Array ._new (np .tensordot (x1 ._array , x2 ._array , axes = axes ))
47
+ if x1 .device != x2 .device :
48
+ raise RuntimeError (f"Arrays from two different devices ({ x1 .device } and { x2 .device } ) can not be combined." )
49
+
50
+ return Array ._new (np .tensordot (x1 ._array , x2 ._array , axes = axes ), device = x1 .device )
45
51
46
52
# Note: this function is new in the array API spec. Unlike transpose, it only
47
53
# transposes the last two axes.
48
54
def matrix_transpose (x : Array , / ) -> Array :
49
55
if x .ndim < 2 :
50
56
raise ValueError ("x must be at least 2-dimensional for matrix_transpose" )
51
- return Array ._new (np .swapaxes (x ._array , - 1 , - 2 ))
57
+ return Array ._new (np .swapaxes (x ._array , - 1 , - 2 ), device = x . device )
52
58
53
59
# Note: vecdot is not in NumPy
54
60
def vecdot (x1 : Array , x2 : Array , / , * , axis : int = - 1 ) -> Array :
@@ -61,6 +67,9 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
61
67
elif axis < min (- 1 , - x1 .ndim , - x2 .ndim ):
62
68
raise ValueError ("axis is out of bounds for x1 and x2" )
63
69
70
+ if x1 .device != x2 .device :
71
+ raise RuntimeError (f"Arrays from two different devices ({ x1 .device } and { x2 .device } ) can not be combined." )
72
+
64
73
# In versions of the standard prior to 2023.12, vecdot applied axis after
65
74
# broadcasting. This is different from applying it before broadcasting
66
75
# when axis is nonnegative. The below code keeps this behavior for
@@ -78,4 +87,4 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
78
87
x2_ = np .moveaxis (x2_ , axis , - 1 )
79
88
80
89
res = x1_ [..., None , :] @ x2_ [..., None ]
81
- return Array ._new (res [..., 0 , 0 ])
90
+ return Array ._new (res [..., 0 , 0 ], device = x1 . device )
0 commit comments