Skip to content

Commit ede45af

Browse files
committed
MAINT: make __array__ raise on python < 3.12
Otherwise, on python 3.11 and below, np.array(array_api_strict_array) becomes a 0D object array.
1 parent 77400d0 commit ede45af

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

array_api_strict/_array_object.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,18 @@ def __repr__(self: Array, /) -> str:
158158
# Instead of `__array__` we now implement the buffer protocol.
159159
# Note that it makes array-apis-strict requiring python>=3.12
160160
def __buffer__(self, flags):
161-
print('__buffer__')
162161
if self._device != CPU_DEVICE:
163162
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
164163
return memoryview(self._array)
165164
def __release_buffer(self, buffer):
166-
print('__release__')
167165
# XXX anything to do here?
166+
pass
167+
168+
def __array__(self, *args, **kwds):
169+
# a stub for python < 3.12; otherwise numpy silently produces object arrays
170+
raise TypeError(
171+
"Interoperation with NumPy requires python >= 3.12. Please upgrade."
172+
)
168173

169174
# These are various helper functions to make the array behavior match the
170175
# spec in places where it either deviates from or is more strict than

array_api_strict/tests/test_array_object.py

+14
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,20 @@ def test_array_conversion():
369369
with pytest.raises((RuntimeError, TypeError)):
370370
asarray([a])
371371

372+
# __buffer__ should work for now for conversion to numpy
373+
a = ones((2, 3))
374+
na = np.array(a)
375+
assert na.shape == (2, 3)
376+
assert na.dtype == np.float64
377+
378+
@pytest.mark.skipif(not sys.version_info.major*100 + sys.version_info.minor < 312,
379+
reason="conversion to numpy errors out unless python >= 3.12"
380+
)
381+
def test_array_conversion_2():
382+
a = ones((2, 3))
383+
with pytest.raises(TypeError):
384+
np.array(a)
385+
372386

373387
def test_allow_newaxis():
374388
a = ones(5)

0 commit comments

Comments
 (0)