Skip to content

Commit 77400d0

Browse files
committed
TST: adapt tests for the lack of __array__
1 parent 1914bbf commit 77400d0

File tree

3 files changed

+15
-34
lines changed

3 files changed

+15
-34
lines changed

array_api_strict/_creation_functions.py

+2-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from contextlib import contextmanager
43
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
54

65
if TYPE_CHECKING:
@@ -16,19 +15,6 @@
1615

1716
import numpy as np
1817

19-
@contextmanager
20-
def allow_array():
21-
"""
22-
Temporarily enable Array.__array__. This is needed for np.array to parse
23-
list of lists of Array objects.
24-
"""
25-
from . import _array_object
26-
original_value = _array_object._allow_array
27-
try:
28-
_array_object._allow_array = True
29-
yield
30-
finally:
31-
_array_object._allow_array = original_value
3218

3319
def _check_valid_dtype(dtype):
3420
# Note: Only spelling dtypes as the dtype objects is supported.
@@ -112,8 +98,8 @@ def asarray(
11298
# Give a better error message in this case. NumPy would convert this
11399
# to an object array. TODO: This won't handle large integers in lists.
114100
raise OverflowError("Integer out of bounds for array dtypes")
115-
with allow_array():
116-
res = np.array(obj, dtype=_np_dtype, copy=copy)
101+
102+
res = np.array(obj, dtype=_np_dtype, copy=copy)
117103
return Array._new(res, device=device)
118104

119105

array_api_strict/tests/test_array_object.py

+6-17
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import operator
23
from builtins import all as all_
34

@@ -351,6 +352,10 @@ def test_array_properties():
351352
assert b.mT.shape == (3, 2)
352353

353354

355+
@pytest.mark.xfail(sys.version_info.major*100 + sys.version_info.minor < 312,
356+
reason="array conversion relies on buffer protocol, and "
357+
"requires python >= 3.12"
358+
)
354359
def test_array_conversion():
355360
# Check that arrays on the CPU device can be converted to NumPy
356361
# but arrays on other devices can't. Note this is testing the logic in
@@ -361,25 +366,9 @@ def test_array_conversion():
361366

362367
for device in ("device1", "device2"):
363368
a = ones((2, 3), device=array_api_strict.Device(device))
364-
with pytest.raises(RuntimeError, match="Can not convert array"):
369+
with pytest.raises((RuntimeError, TypeError)):
365370
asarray([a])
366371

367-
def test__array__():
368-
# __array__ should work for now
369-
a = ones((2, 3))
370-
np.array(a)
371-
372-
# Test the _allow_array private global flag for disabling it in the
373-
# future.
374-
from .. import _array_object
375-
original_value = _array_object._allow_array
376-
try:
377-
_array_object._allow_array = False
378-
a = ones((2, 3))
379-
with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"):
380-
np.array(a)
381-
finally:
382-
_array_object._allow_array = original_value
383372

384373
def test_allow_newaxis():
385374
a = ones(5)

array_api_strict/tests/test_creation_functions.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import warnings
23

34
from numpy.testing import assert_raises
@@ -97,7 +98,12 @@ def test_asarray_copy():
9798
a[0] = 0
9899
assert all(b[0] == 0)
99100

100-
def test_asarray_list_of_lists():
101+
102+
@pytest.mark.xfail(sys.version_info.major*100 + sys.version_info.minor < 312,
103+
reason="array conversion relies on buffer protocol, and "
104+
"requires python >= 3.12"
105+
)
106+
def test_asarray_list_of_arrays():
101107
a = asarray(1, dtype=int16)
102108
b = asarray([1], dtype=int16)
103109
res = asarray([a, a])

0 commit comments

Comments
 (0)