Skip to content

Commit 1cb10db

Browse files
committed
BUG: astype(..., copy=True) doesn't copy on dask
1 parent e5dd419 commit 1cb10db

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

array_api_compat/dask/array/_aliases.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,14 @@
3939

4040
isdtype = get_xp(np)(_aliases.isdtype)
4141
unstack = get_xp(da)(_aliases.unstack)
42-
astype = _aliases.astype
42+
43+
def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
44+
if not copy and dtype == x.dtype:
45+
return x
46+
# dask astype doesn't respect copy=True so copy
47+
# manually via numpy
48+
x = np.array(x, dtype=dtype, copy=copy)
49+
return da.from_array(x)
4350

4451
# Common aliases
4552

tests/test_array_namespace.py

+13
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import array_api_compat
1111
from array_api_compat import array_namespace
12+
import array_api_compat.numpy
1213

1314
from ._helpers import import_, all_libraries, wrapped_libraries
1415

@@ -22,6 +23,7 @@ def test_array_namespace(library, api_version, use_compat):
2223
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
2324
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
2425
return
26+
print(use_compat)
2527
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
2628

2729
if use_compat is False or use_compat is None and library not in wrapped_libraries:
@@ -36,6 +38,17 @@ def test_array_namespace(library, api_version, use_compat):
3638
assert namespace == jax.experimental.array_api
3739
else:
3840
assert namespace == xp
41+
elif use_compat is None:
42+
if library == "dask.array":
43+
# dask should always return wrapped version
44+
# since dask.array is not array API compatible
45+
assert namespace == array_api_compat.dask.array
46+
elif library == "numpy":
47+
assert namespace == array_api_compat.numpy
48+
elif library == "torch":
49+
assert namespace == array_api_compat.torch
50+
else:
51+
assert namespace == xp
3952
else:
4053
if library == "dask.array":
4154
assert namespace == array_api_compat.dask.array

tests/test_common.py

+13
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,16 @@ def test_asarray_copy(library):
272272
assert all(b[0] == 1.0)
273273
else:
274274
assert all(b[0] == 0.0)
275+
276+
@pytest.mark.parametrize("library", wrapped_libraries)
277+
def test_astype_copy(library):
278+
# array-api-tests currently doesn't check copy=True
279+
# makes a copy when dtypes are the same
280+
# so we check that here
281+
xp = import_(library, wrapper=True)
282+
a = xp.asarray([1])
283+
b = xp.astype(a, a.dtype, copy=True)
284+
285+
a[0] = 10
286+
287+
assert b[0] == 1

0 commit comments

Comments
 (0)