Skip to content

Commit 84bf725

Browse files
ENH: implement at (#53)
* ENH: add new function `at` * MAINT: use released array-api-compat * update lock-file * Update dependencies * Add xpx namespace in documentation * Change copy to default to None * raise on incompatible cast * Update tests/test_at.py --------- Co-authored-by: Lucas Colley <[email protected]>
1 parent 397e243 commit 84bf725

File tree

12 files changed

+7486
-1103
lines changed

12 files changed

+7486
-1103
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
strategy:
4949
fail-fast: false
5050
matrix:
51-
environment: [ci-py310, ci-py313]
51+
environment: [ci-py310, ci-py313, ci-backends]
5252
runs-on: [ubuntu-latest]
5353

5454
steps:

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
at
910
atleast_nd
1011
cov
1112
create_diagonal

docs/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
intersphinx_mapping = {
5555
"python": ("https://docs.python.org/3", None),
56+
"jax": ("https://jax.readthedocs.io/en/latest", None),
5657
}
5758

5859
nitpick_ignore = [

pixi.lock

+6,926-1,087
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+39-4
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
"Typing :: Typed",
2727
]
2828
dynamic = ["version"]
29-
dependencies = ["array-api-compat>=1.1.1"]
29+
dependencies = ["array-api-compat>=1.10.0,<2"]
3030

3131
[project.optional-dependencies]
3232
tests = [
@@ -62,8 +62,8 @@ channels = ["https://prefix.dev/conda-forge"]
6262
platforms = ["linux-64", "osx-arm64", "win-64"]
6363

6464
[tool.pixi.dependencies]
65-
python = ">=3.10.15,<3.14"
66-
array-api-compat = ">=1.1.1"
65+
python = ">=3.10,<3.14"
66+
array-api-compat = ">=1.10.0,<2"
6767

6868
[tool.pixi.pypi-dependencies]
6969
array-api-extra = { path = ".", editable = true }
@@ -130,6 +130,35 @@ python = "~=3.10.0"
130130
[tool.pixi.feature.py313.dependencies]
131131
python = "~=3.13.0"
132132

133+
# Backends that can run on CPU-only hosts
134+
[tool.pixi.feature.backends.target.linux-64.dependencies]
135+
pytorch = "*"
136+
dask = "*"
137+
sparse = ">=0.15"
138+
jax = "*"
139+
140+
[tool.pixi.feature.backends.target.osx-arm64.dependencies]
141+
pytorch = "*"
142+
dask = "*"
143+
sparse = ">=0.15"
144+
jax = "*"
145+
146+
[tool.pixi.feature.backends.target.win-64.dependencies]
147+
# pytorch = "*" # Package unavailable on Windows
148+
dask = "*"
149+
sparse = ">=0.15"
150+
# jax = "*" # Package unavailable on Windows
151+
152+
# Backends that require a GPU host and a CUDA driver
153+
[tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
154+
cupy = "*"
155+
156+
[tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies]
157+
# cupy = "*" # Package unavailable on macOSX
158+
159+
[tool.pixi.feature.cuda-backends.target.win-64.dependencies]
160+
cupy = "*"
161+
133162
[tool.pixi.environments]
134163
default = { solve-group = "default" }
135164
lint = { features = ["lint"], solve-group = "default" }
@@ -138,7 +167,9 @@ docs = { features = ["docs"], solve-group = "default" }
138167
dev = { features = ["lint", "tests", "docs", "dev"], solve-group = "default" }
139168
ci-py310 = ["py310", "tests"]
140169
ci-py313 = ["py313", "tests"]
141-
170+
# CUDA not available on free github actions
171+
ci-backends = ["py310", "tests", "backends"]
172+
tests-backends = ["py310", "tests", "backends", "cuda-backends"]
142173

143174
# pytest
144175

@@ -195,6 +226,8 @@ reportAny = false
195226
reportExplicitAny = false
196227
# data-apis/array-api-strict#6
197228
reportUnknownMemberType = false
229+
# no array-api-compat type stubs
230+
reportUnknownVariableType = false
198231

199232

200233
# Ruff
@@ -236,6 +269,7 @@ ignore = [
236269
"PLR09", # Too many <...>
237270
"PLR2004", # Magic value used in comparison
238271
"ISC001", # Conflicts with formatter
272+
"N801", # Class name should use CapWords convention
239273
"N802", # Function name should be lowercase
240274
"N806", # Variable in function should be lowercase
241275
]
@@ -271,6 +305,7 @@ checks = [
271305
"ES01",
272306
]
273307
exclude = [ # don't report on objects that match any of these regex
308+
'.*test_at.*',
274309
'.*test_funcs.*',
275310
'.*test_utils.*',
276311
'.*test_version.*',

src/array_api_extra/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Extra array functions built on top of the array API standard."""
22

33
from ._funcs import (
4+
at,
45
atleast_nd,
56
cov,
67
create_diagonal,
@@ -16,6 +17,7 @@
1617
# pylint: disable=duplicate-code
1718
__all__ = [
1819
"__version__",
20+
"at",
1921
"atleast_nd",
2022
"cov",
2123
"create_diagonal",

0 commit comments

Comments
 (0)