Skip to content

Commit e0b2a64

Browse files
committed
Multi-device support in linear algebra functions
1 parent 724e071 commit e0b2a64

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

array_api_strict/_info.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Optional, Union, Tuple, List
77
from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info
88

9-
from ._array_object import CPU_DEVICE, Device
9+
from ._array_object import ALL_DEVICES, CPU_DEVICE
1010
from ._flags import get_array_api_strict_flags, requires_api_version
1111
from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128
1212

@@ -121,7 +121,7 @@ def dtypes(
121121

122122
@requires_api_version('2023.12')
123123
def devices() -> List[device]:
124-
return [CPU_DEVICE, Device("device1"), Device("device2")]
124+
return list(ALL_DEVICES)
125125

126126
__all__ = [
127127
"capabilities",

array_api_strict/_linear_algebra_functions.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ def matmul(x1: Array, x2: Array, /) -> Array:
3030
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
3131
raise TypeError('Only numeric dtypes are allowed in matmul')
3232

33-
return Array._new(np.matmul(x1._array, x2._array))
33+
if x1.device != x2.device:
34+
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
35+
36+
return Array._new(np.matmul(x1._array, x2._array), device=x1.device)
3437

3538
# Note: tensordot is the numpy top-level namespace but not in np.linalg
3639

@@ -41,14 +44,17 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int],
4144
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
4245
raise TypeError('Only numeric dtypes are allowed in tensordot')
4346

44-
return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
47+
if x1.device != x2.device:
48+
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
49+
50+
return Array._new(np.tensordot(x1._array, x2._array, axes=axes), device=x1.device)
4551

4652
# Note: this function is new in the array API spec. Unlike transpose, it only
4753
# transposes the last two axes.
4854
def matrix_transpose(x: Array, /) -> Array:
4955
if x.ndim < 2:
5056
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
51-
return Array._new(np.swapaxes(x._array, -1, -2))
57+
return Array._new(np.swapaxes(x._array, -1, -2), device=x.device)
5258

5359
# Note: vecdot is not in NumPy
5460
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
@@ -61,6 +67,9 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
6167
elif axis < min(-1, -x1.ndim, -x2.ndim):
6268
raise ValueError("axis is out of bounds for x1 and x2")
6369

70+
if x1.device != x2.device:
71+
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
72+
6473
# In versions of the standard prior to 2023.12, vecdot applied axis after
6574
# broadcasting. This is different from applying it before broadcasting
6675
# when axis is nonnegative. The below code keeps this behavior for
@@ -78,4 +87,4 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
7887
x2_ = np.moveaxis(x2_, axis, -1)
7988

8089
res = x1_[..., None, :] @ x2_[..., None]
81-
return Array._new(res[..., 0, 0])
90+
return Array._new(res[..., 0, 0], device=x1.device)

0 commit comments

Comments
 (0)