Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ ignore = [
"FIX", # flake8-fixme
"ISC001", # Conflicts with formatter
"PYI041", # Use `float` instead of `int | float`
"TD002", # Missing author in TODO
"TD003", # Missing issue link for this TODO
]

[tool.ruff.lint.pydocstyle]
Expand Down
18 changes: 17 additions & 1 deletion src/array_api_typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,25 @@
"Array",
"HasArrayNamespace",
"HasDType",
"HasDevice",
"HasMatrixTranspose",
"HasNDim",
"HasShape",
"HasSize",
"HasTranspose",
"__version__",
"__version_tuple__",
)

from ._array import Array, HasArrayNamespace, HasDType
from ._array import (
Array,
HasArrayNamespace,
HasDevice,
HasDType,
HasMatrixTranspose,
HasNDim,
HasShape,
HasSize,
HasTranspose,
)
from ._version import version as __version__, version_tuple as __version_tuple__
143 changes: 138 additions & 5 deletions src/array_api_typing/_array.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
__all__ = (
"Array",
"HasArrayNamespace",
"HasDType",
"HasDevice",
"HasMatrixTranspose",
"HasNDim",
"HasShape",
"HasSize",
"HasTranspose",
)

from types import ModuleType
from typing import Literal, Protocol
from typing import Literal, Protocol, Self
from typing_extensions import TypeVar

NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
DTypeT_co = TypeVar("DTypeT_co", covariant=True)
DeviceT_co = TypeVar("DeviceT_co", covariant=True, default=object)


class HasArrayNamespace(Protocol[NamespaceT_co]):
Expand Down Expand Up @@ -67,19 +75,144 @@ def dtype(self, /) -> DTypeT_co:
...


class HasDevice(Protocol[DeviceT_co]):
"""Protocol for array classes that have a device attribute."""

@property
def device(self) -> DeviceT_co:
"""Hardware device the array data resides on."""
...


class HasMatrixTranspose(Protocol):
"""Protocol for array classes that have a matrix transpose attribute."""

@property
def mT(self) -> Self: # noqa: N802
"""Transpose of a matrix (or a stack of matrices).

If an array instance has fewer than two dimensions, an error should be
raised.

Returns:
Self: array whose last two dimensions (axes) are permuted in reverse
order relative to original array (i.e., for an array instance
having shape `(..., M, N)`, the returned array must have shape
`(..., N, M))`. The returned array must have the same data type
as the original array.

"""
...


class HasNDim(Protocol):
"""Protocol for array classes that have a number of dimensions attribute."""

@property
def ndim(self) -> int:
"""Number of array dimensions (axes).

Returns:
int: number of array dimensions (axes).

"""
...


class HasShape(Protocol):
"""Protocol for array classes that have a shape attribute."""

@property
def shape(self) -> tuple[int | None, ...]:
"""Shape of the array.

Returns:
tuple[int | None, ...]: array dimensions. An array dimension must be None
if and only if a dimension is unknown.

Notes:
For array libraries having graph-based computational models, array
dimensions may be unknown due to data-dependent operations (e.g.,
boolean indexing; `A[:, B > 0]`) and thus cannot be statically
resolved without knowing array contents.

"""
...


class HasSize(Protocol):
"""Protocol for array classes that have a size attribute."""

@property
def size(self) -> int | None:
"""Number of elements in an array.

Returns:
int | None: number of elements in an array. The returned value must
be `None` if and only if one or more array dimensions are
unknown.

Notes:
This must equal the product of the array's dimensions.

"""
...


class HasTranspose(Protocol):
"""Protocol for array classes that support the transpose operation."""

@property
def T(self) -> Self: # noqa: N802
"""Transpose of the array.

The array instance must be two-dimensional. If the array instance is not
two-dimensional, an error should be raised.

Returns:
Self: two-dimensional array whose first and last dimensions (axes)
are permuted in reverse order relative to original array. The
returned array must have the same data type as the original
array.

Notes:
Limiting the transpose to two-dimensional arrays (matrices) deviates
from the NumPy et al practice of reversing all axes for arrays
having more than two-dimensions. This is intentional, as reversing
all axes was found to be problematic (e.g., conflicting with the
mathematical definition of a transpose which is limited to matrices;
not operating on batches of matrices; et cetera). In order to
reverse all axes, one is recommended to use the functional
`PermuteDims` interface found in this specification.

"""
...


class Array(
HasArrayNamespace[NamespaceT_co],
# ------ Attributes -------
HasDType[DTypeT_co],
HasDevice[DeviceT_co],
HasMatrixTranspose,
HasNDim,
HasShape,
HasSize,
HasTranspose,
# ------- Methods ---------
HasArrayNamespace[NamespaceT_co],
# -------------------------
Protocol[DTypeT_co, NamespaceT_co],
Protocol[DTypeT_co, DeviceT_co, NamespaceT_co],
):
"""Array API specification for array object attributes and methods.

The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT,
NamespaceT]`` where:
The type is: ``Array[+DTypeT, +DeviceT = object, +NamespaceT = ModuleType] =
Array[DTypeT, DeviceT, NamespaceT]`` where:

- `DTypeT` is the data type of the array elements.
- `DeviceT` is the type of the device attribute. It defaults to `object` to
enable skipping device specification. Array objects supporting device
management can specify a more specific type if they use types (as opposed
to object instances) to distinguish between different devices.
- `NamespaceT` is the type of the array namespace. It defaults to
`ModuleType`, which is the most common form of array namespace (e.g.,
`numpy`, `cupy`, etc.). However, it can be any type, e.g. a
Expand Down
36 changes: 32 additions & 4 deletions tests/integration/test_numpy1p0.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# mypy: disable-error-code="no-redef"

from types import ModuleType
from typing import Any
from typing import Any, assert_type

import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
from numpy import dtype
Expand Down Expand Up @@ -43,11 +43,39 @@ _: xpt.HasDType[dtype[Any]] = nparr_f32
# `xpt.Array`

# Check NamespaceT_co assignment
a_ns: xpt.Array[Any, ModuleType] = nparr
a_ns: xpt.Array[Any, Any, ModuleType] = nparr

# Check DTypeT_co assignment
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
# type annotate specific dtypes like `np.float32` or `np.int32`.
_: xpt.Array[dtype[Any]] = nparr
_: xpt.Array[dtype[Any]] = nparr_i32
_: xpt.Array[dtype[Any]] = nparr_f32
x_f32: xpt.Array[dtype[Any]] = nparr_f32
x_i32: xpt.Array[dtype[Any]] = nparr_i32

# Check Attribute `.dtype`
assert_type(x_f32.dtype, dtype[Any])
assert_type(x_i32.dtype, dtype[Any])

# Check Attribute `.device`
assert_type(x_f32.device, object)
assert_type(x_i32.device, object)

# Check Attribute `.mT`
assert_type(x_f32.mT, xpt.Array[dtype[Any]])
assert_type(x_i32.mT, xpt.Array[dtype[Any]])

# Check Attribute `.ndim`
assert_type(x_f32.ndim, int)
assert_type(x_i32.ndim, int)

# Check Attribute `.shape`
assert_type(x_f32.shape, tuple[int | None, ...])
assert_type(x_i32.shape, tuple[int | None, ...])

# Check Attribute `.size`
assert_type(x_f32.size, int | None)
assert_type(x_i32.size, int | None)

# Check Attribute `.T`
assert_type(x_f32.T, xpt.Array[dtype[Any]])
assert_type(x_i32.T, xpt.Array[dtype[Any]])
57 changes: 50 additions & 7 deletions tests/integration/test_numpy2p0.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# mypy: disable-error-code="no-redef"

from types import ModuleType
from typing import Any, TypeAlias
from typing import Any, TypeAlias, assert_type

import numpy as np
import numpy.typing as npt
Expand All @@ -11,12 +11,13 @@ import array_api_typing as xpt
# DType aliases
F32: TypeAlias = np.float32
I32: TypeAlias = np.int32
B: TypeAlias = np.bool_

# Define NDArrays against which we can test the protocols
nparr: npt.NDArray[Any]
nparr_i32: npt.NDArray[I32]
nparr_f32: npt.NDArray[F32]
nparr_b: npt.NDArray[np.bool_]
nparr_b: npt.NDArray[B]

# =========================================================
# `xpt.HasArrayNamespace`
Expand All @@ -42,16 +43,58 @@ _: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
_: xpt.HasDType[Any] = nparr
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b
_: xpt.HasDType[np.dtype[B]] = nparr_b

# =========================================================
# `xpt.Array`

# Check NamespaceT_co assignment
a_ns: xpt.Array[Any, ModuleType] = nparr
a_ns: xpt.Array[Any, Any, ModuleType] = nparr

# Check DTypeT_co assignment
_: xpt.Array[Any] = nparr
_: xpt.Array[np.dtype[I32]] = nparr_i32
_: xpt.Array[np.dtype[F32]] = nparr_f32
_: xpt.Array[np.dtype[np.bool_]] = nparr_b
x_f32: xpt.Array[np.dtype[F32]] = nparr_f32
x_i32: xpt.Array[np.dtype[I32]] = nparr_i32
x_b: xpt.Array[np.dtype[B]] = nparr_b

# Check Attribute `.dtype`
assert_type(x_f32.dtype, np.dtype[F32])
assert_type(x_i32.dtype, np.dtype[I32])
assert_type(x_b.dtype, np.dtype[B])

# Check DeviceT_co assignment
x_gooddevice: xpt.Array[Any, object, Any] = nparr
assert_type(x_gooddevice.device, object)

x_baddevice: xpt.Array[Any, int, Any] = nparr # type: ignore[assignment]
_: int = x_baddevice.device

# Check Attribute `.device`
assert_type(x_f32.device, object)
assert_type(x_i32.device, object)
assert_type(x_b.device, object)

# Check Attribute `.mT`
assert_type(x_f32.mT, xpt.Array[np.dtype[F32]])
assert_type(x_i32.mT, xpt.Array[np.dtype[I32]])
assert_type(x_b.mT, xpt.Array[np.dtype[B]])

# Check Attribute `.ndim`
assert_type(x_f32.ndim, int)
assert_type(x_i32.ndim, int)
assert_type(x_b.ndim, int)

# Check Attribute `.shape`
assert_type(x_f32.shape, tuple[int | None, ...])
assert_type(x_i32.shape, tuple[int | None, ...])
assert_type(x_b.shape, tuple[int | None, ...])

# Check Attribute `.size`
assert_type(x_f32.size, int | None)
assert_type(x_i32.size, int | None)
assert_type(x_b.size, int | None)

# Check Attribute `.T`
assert_type(x_f32.T, xpt.Array[np.dtype[F32]])
assert_type(x_i32.T, xpt.Array[np.dtype[I32]])
assert_type(x_b.T, xpt.Array[np.dtype[B]])
Loading
Loading