Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit d6a53f9

Browse files
committedApr 21, 2025·
Overhaul test_all
1 parent 5e14b53 commit d6a53f9

File tree

16 files changed

+262
-95
lines changed

16 files changed

+262
-95
lines changed
 

‎array_api_compat/common/_aliases.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -720,8 +720,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
720720
"finfo",
721721
"iinfo",
722722
]
723-
_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
724-
725723

726724
def __dir__() -> list[str]:
727725
return __all__

‎array_api_compat/common/_helpers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,5 @@ def is_lazy_array(x: object) -> bool:
10421042
"to_device",
10431043
]
10441044

1045-
_all_ignore = ["sys", "math", "inspect", "warnings"]
1046-
10471045
def __dir__() -> list[str]:
10481046
return __all__

‎array_api_compat/common/_linalg.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,6 @@ def trace(
225225
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
226226
'trace']
227227

228-
_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']
229-
230228

231229
def __dir__() -> list[str]:
232230
return __all__

‎array_api_compat/cupy/_aliases.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,5 +160,3 @@ def count_nonzero(
160160
'atan2', 'atanh', 'bitwise_left_shift',
161161
'bitwise_invert', 'bitwise_right_shift',
162162
'bool', 'concat', 'count_nonzero', 'pow', 'sign']
163-
164-
_all_ignore = ['cp', 'get_xp']

‎array_api_compat/cupy/_typing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
__all__ = ["Array", "DType", "Device"]
4-
_all_ignore = ["cp"]
54

65
from typing import TYPE_CHECKING
76

‎array_api_compat/dask/array/_aliases.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,6 @@ def count_nonzero(
369369
"bitwise_left_shift", "bitwise_right_shift", "bitwise_invert",
370370
] # fmt: skip
371371
__all__ += _aliases.__all__
372-
_all_ignore = ["array_namespace", "get_xp", "da", "np"]
373-
374372

375373
def __dir__() -> list[str]:
376374
return __all__

‎array_api_compat/dask/array/fft.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@
1818
rfftfreq = get_xp(da)(_fft.rfftfreq)
1919

2020
__all__ = fft_all + ["fftfreq", "rfftfreq"]
21-
_all_ignore = ["da", "fft_all", "get_xp", "warnings"]
21+
22+
def __dir__() -> list[str]:
23+
return __all__

‎array_api_compat/dask/array/linalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,5 @@ def svdvals(x: _Array) -> _Array:
6969
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
7070
"vector_norm", "diagonal"]
7171

72-
_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings']
72+
def __dir__() -> list[str]:
73+
return __all__

‎array_api_compat/numpy/_aliases.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def count_nonzero(
157157
else:
158158
unstack = get_xp(np)(_aliases.unstack)
159159

160-
__all__ = [
160+
__all__ = _aliases.__all__ + [
161161
"__array_namespace_info__",
162162
"asarray",
163163
"astype",
@@ -176,8 +176,6 @@ def count_nonzero(
176176
"count_nonzero",
177177
"pow",
178178
]
179-
__all__ += _aliases.__all__
180-
_all_ignore = ["np", "get_xp"]
181179

182180

183181
def __dir__() -> list[str]:

‎array_api_compat/numpy/_typing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
Array: TypeAlias = np.ndarray
2424

2525
__all__ = ["Array", "DType", "Device"]
26-
_all_ignore = ["np"]
2726

2827

2928
def __dir__() -> list[str]:

‎array_api_compat/numpy/fft.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
from numpy.fft import __all__ as fft_all
32
from numpy.fft import fft2, ifft2, irfft2, rfft2
43

54
from .._internal import get_xp
@@ -21,15 +20,7 @@
2120
ifftshift = get_xp(np)(_fft.ifftshift)
2221

2322

24-
__all__ = ["rfft2", "irfft2", "fft2", "ifft2"]
25-
__all__ += _fft.__all__
26-
23+
__all__ = _fft.__all__ + ["rfft2", "irfft2", "fft2", "ifft2"]
2724

2825
def __dir__() -> list[str]:
2926
return __all__
30-
31-
32-
del get_xp
33-
del np
34-
del fft_all
35-
del _fft

‎array_api_compat/numpy/linalg.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def solve(x1: Array, x2: Array, /) -> Array:
120120
vector_norm = get_xp(np)(_linalg.vector_norm)
121121

122122

123-
__all__ = [
123+
__all__ = _linalg.__all__ + [
124124
"LinAlgError",
125125
"cond",
126126
"det",
@@ -132,12 +132,11 @@ def solve(x1: Array, x2: Array, /) -> Array:
132132
"matrix_power",
133133
"multi_dot",
134134
"norm",
135+
"solve",
135136
"tensorinv",
136137
"tensorsolve",
138+
"vector_norm",
137139
]
138-
__all__ += _linalg.__all__
139-
__all__ += ["solve", "vector_norm"]
140-
141140

142141
def __dir__() -> list[str]:
143142
return __all__

‎array_api_compat/torch/_aliases.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -841,5 +841,3 @@ def sign(x: Array, /) -> Array:
841841
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
842842
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
843843
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat']
844-
845-
_all_ignore = ['torch', 'get_xp']

‎array_api_compat/torch/fft.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,3 @@ def ifftshift(
8181
"fftshift",
8282
"ifftshift",
8383
]
84-
85-
_all_ignore = ['torch']

‎array_api_compat/torch/linalg.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,5 @@ def vector_norm(
113113
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
114114
'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
115115

116-
_all_ignore = ['torch_linalg', 'sum']
117-
118-
del linalg_all
119-
120116
def __dir__() -> list[str]:
121117
return __all__

‎tests/test_all.py

Lines changed: 252 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,259 @@
1-
"""
2-
Test that files that define __all__ aren't missing any exports.
1+
"""Test exported names"""
32

4-
You can add names that shouldn't be exported to _all_ignore, like
3+
import builtins
54

6-
_all_ignore = ['sys']
5+
import numpy as np
6+
import pytest
77

8-
This is preferable to del-ing the names as this will break any name that is
9-
used inside of a function. Note that names starting with an underscore are automatically ignored.
10-
"""
8+
from ._helpers import wrapped_libraries
119

10+
NAMES = {
11+
"": [
12+
# Inspection
13+
"__array_api_version__",
14+
"__array_namespace_info__",
15+
# Constants
16+
"e",
17+
"inf",
18+
"nan",
19+
"newaxis",
20+
"pi",
21+
# Creation functions
22+
"arange",
23+
"asarray",
24+
"empty",
25+
"empty_like",
26+
"eye",
27+
"from_dlpack",
28+
"full",
29+
"full_like",
30+
"linspace",
31+
"meshgrid",
32+
"ones",
33+
"ones_like",
34+
"tril",
35+
"triu",
36+
"zeros",
37+
"zeros_like",
38+
# Data Type Functions
39+
"astype",
40+
"can_cast",
41+
"finfo",
42+
"iinfo",
43+
"isdtype",
44+
"result_type",
45+
# Data Types
46+
"bool",
47+
"int8",
48+
"int16",
49+
"int32",
50+
"int64",
51+
"uint8",
52+
"uint16",
53+
"uint32",
54+
"uint64",
55+
"float32",
56+
"float64",
57+
"complex64",
58+
"complex128",
59+
# Elementwise Functions
60+
"abs",
61+
"acos",
62+
"acosh",
63+
"add",
64+
"asin",
65+
"asinh",
66+
"atan",
67+
"atan2",
68+
"atanh",
69+
"bitwise_and",
70+
"bitwise_left_shift",
71+
"bitwise_invert",
72+
"bitwise_or",
73+
"bitwise_right_shift",
74+
"bitwise_xor",
75+
"ceil",
76+
"clip",
77+
"conj",
78+
"copysign",
79+
"cos",
80+
"cosh",
81+
"divide",
82+
"equal",
83+
"exp",
84+
"expm1",
85+
"floor",
86+
"floor_divide",
87+
"greater",
88+
"greater_equal",
89+
"hypot",
90+
"imag",
91+
"isfinite",
92+
"isinf",
93+
"isnan",
94+
"less",
95+
"less_equal",
96+
"log",
97+
"log1p",
98+
"log2",
99+
"log10",
100+
"logaddexp",
101+
"logical_and",
102+
"logical_not",
103+
"logical_or",
104+
"logical_xor",
105+
"maximum",
106+
"minimum",
107+
"multiply",
108+
"negative",
109+
"nextafter",
110+
"not_equal",
111+
"positive",
112+
"pow",
113+
"real",
114+
"reciprocal",
115+
"remainder",
116+
"round",
117+
"sign",
118+
"signbit",
119+
"sin",
120+
"sinh",
121+
"square",
122+
"sqrt",
123+
"subtract",
124+
"tan",
125+
"tanh",
126+
"trunc",
127+
# Indexing Functions
128+
"take",
129+
"take_along_axis",
130+
# Linear Algebra Functions
131+
"matmul",
132+
"matrix_transpose",
133+
"tensordot",
134+
"vecdot",
135+
# Manipulation Functions
136+
"broadcast_arrays",
137+
"broadcast_to",
138+
"concat",
139+
"expand_dims",
140+
"flip",
141+
"moveaxis",
142+
"permute_dims",
143+
"repeat",
144+
"reshape",
145+
"roll",
146+
"squeeze",
147+
"stack",
148+
"tile",
149+
"unstack",
150+
# Searching Functions
151+
"argmax",
152+
"argmin",
153+
"count_nonzero",
154+
"nonzero",
155+
"searchsorted",
156+
"where",
157+
# Set functions
158+
"unique_all",
159+
"unique_counts",
160+
"unique_inverse",
161+
"unique_values",
162+
# Sorting Functions
163+
"argsort",
164+
"sort",
165+
# Statistical Functions
166+
"cumulative_prod",
167+
"cumulative_sum",
168+
"max",
169+
"mean",
170+
"min",
171+
"prod",
172+
"std",
173+
"sum",
174+
"var",
175+
# Utility Functions
176+
"all",
177+
"any",
178+
"diff",
179+
],
180+
"fft": [
181+
"fft",
182+
"ifft",
183+
"fftn",
184+
"ifftn",
185+
"rfft",
186+
"irfft",
187+
"rfftn",
188+
"irfftn",
189+
"hfft",
190+
"ihfft",
191+
"fftfreq",
192+
"rfftfreq",
193+
"fftshift",
194+
"ifftshift",
195+
],
196+
"linalg": [
197+
"cholesky",
198+
"cross",
199+
"det",
200+
"diagonal",
201+
"eigh",
202+
"eigvalsh",
203+
"inv",
204+
"matmul",
205+
"matrix_norm",
206+
"matrix_power",
207+
"matrix_rank",
208+
"matrix_transpose",
209+
"outer",
210+
"pinv",
211+
"qr",
212+
"slogdet",
213+
"solve",
214+
"svd",
215+
"svdvals",
216+
"tensordot",
217+
"trace",
218+
"vecdot",
219+
"vector_norm",
220+
],
221+
}
12222

13-
import sys
223+
XFAILS = {
224+
("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [],
225+
("dask.array", ""): ["from_dlpack", "take_along_axis"],
226+
("dask.array", "linalg"): [
227+
"cross",
228+
"det",
229+
"eigh",
230+
"eigvalsh",
231+
"matrix_power",
232+
"pinv",
233+
"slogdet",
234+
],
235+
}
14236

15-
from ._helpers import import_, wrapped_libraries
16237

17-
import pytest
18-
import typing
19-
20-
TYPING_NAMES = frozenset((
21-
"Array",
22-
"Device",
23-
"DType",
24-
"Namespace",
25-
"NestedSequence",
26-
"SupportsBufferProtocol",
27-
))
28-
29-
@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
30-
def test_all(library):
31-
if library == "common":
32-
import array_api_compat.common # noqa: F401
33-
else:
34-
import_(library, wrapper=True)
35-
36-
# NB: iterate over a copy to avoid a "dictionary size changed" error
37-
for mod_name in sys.modules.copy():
38-
if not mod_name.startswith('array_api_compat.' + library):
39-
continue
40-
41-
module = sys.modules[mod_name]
42-
43-
# TODO: We should define __all__ in the __init__.py files and test it
44-
# there too.
45-
if not hasattr(module, '__all__'):
46-
continue
47-
48-
dir_names = [n for n in dir(module) if not n.startswith('_')]
49-
if '__array_namespace_info__' in dir(module):
50-
dir_names.append('__array_namespace_info__')
51-
ignore_all_names = set(getattr(module, '_all_ignore', ()))
52-
ignore_all_names |= set(dir(typing))
53-
ignore_all_names |= {"annotations"}
54-
if not module.__name__.endswith("._typing"):
55-
ignore_all_names |= TYPING_NAMES
56-
dir_names = set(dir_names) - set(ignore_all_names)
57-
all_names = module.__all__
58-
59-
if set(dir_names) != set(all_names):
60-
extra_dir = set(dir_names) - set(all_names)
61-
extra_all = set(all_names) - set(dir_names)
62-
assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}"
63-
assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}"
238+
@pytest.mark.parametrize("module", list(NAMES))
239+
@pytest.mark.parametrize("library", wrapped_libraries)
240+
def test_dir(library, module):
241+
"""Test that dir() isn't missing any exports."""
242+
xp = pytest.importorskip(f"array_api_compat.{library}")
243+
mod = getattr(xp, module) if module else xp
244+
missing = set(NAMES[module]) - set(dir(mod))
245+
xfail = set(XFAILS.get((library, module), []))
246+
xpass = xfail - missing
247+
fails = missing - xfail
248+
assert not xpass, "Names in XFAILS are defined: %s" % xpass
249+
assert not fails, "Missing exports: %s" % fails
250+
251+
252+
@pytest.mark.parametrize(
253+
"name", [name for name in NAMES[""] if hasattr(builtins, name)]
254+
)
255+
@pytest.mark.parametrize("library", wrapped_libraries)
256+
def test_builtins_collision(library, name):
257+
"""Test that xp.bool is not accidentally builtins.bool, etc."""
258+
xp = pytest.importorskip(f"array_api_compat.{library}")
259+
assert getattr(xp, name) is not getattr(builtins, name)

0 commit comments

Comments
 (0)
Please sign in to comment.