Skip to content

Commit 1acba0c

Browse files
authored
BUG: take_along_axis: add numpy and cupy aliases, skip testing on dask (#317)
1 parent e600449 commit 1acba0c

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

array_api_compat/cupy/_aliases.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ def count_nonzero(
124124
return result
125125

126126

127+
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
128+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
129+
return cp.take_along_axis(x, indices, axis=axis)
130+
131+
127132
# These functions are completely new here. If the library already has them
128133
# (i.e., numpy 2.0), use the library version instead of our wrapper.
129134
if hasattr(cp, 'vecdot'):
@@ -145,6 +150,7 @@ def count_nonzero(
145150
'acos', 'acosh', 'asin', 'asinh', 'atan',
146151
'atan2', 'atanh', 'bitwise_left_shift',
147152
'bitwise_invert', 'bitwise_right_shift',
148-
'bool', 'concat', 'count_nonzero', 'pow', 'sign']
153+
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
154+
'take_along_axis']
149155

150156
_all_ignore = ['cp', 'get_xp']

array_api_compat/numpy/_aliases.py

+6
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ def count_nonzero(
140140
return result
141141

142142

143+
# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
144+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
145+
return np.take_along_axis(x, indices, axis=axis)
146+
147+
143148
# These functions are completely new here. If the library already has them
144149
# (i.e., numpy 2.0), use the library version instead of our wrapper.
145150
if hasattr(np, "vecdot"):
@@ -175,6 +180,7 @@ def count_nonzero(
175180
"concat",
176181
"count_nonzero",
177182
"pow",
183+
"take_along_axis"
178184
]
179185
__all__ += _aliases.__all__
180186
_all_ignore = ["np", "get_xp"]

dask-xfails.txt

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ array_api_tests/test_creation_functions.py::test_linspace
2424
# Shape mismatch
2525
array_api_tests/test_indexing_functions.py::test_take
2626

27+
# missing `take_along_axis`, https://github.com/dask/dask/issues/3663
28+
array_api_tests/test_indexing_functions.py::test_take_along_axis
29+
2730
# Array methods and attributes not already on da.Array cannot be wrapped
2831
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
2932
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]

0 commit comments

Comments
 (0)