Skip to content

Commit 0a14d6c

Browse files
authored
Merge pull request #318 from ev-br/count_nonzero_torch_tuple_axis_
BUG: torch: fix count_nonzero with axis tuple and keepdims
2 parents 2adea00 + 8c62443 commit 0a14d6c

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

array_api_compat/torch/_aliases.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -548,8 +548,12 @@ def count_nonzero(
548548
) -> Array:
549549
result = torch.count_nonzero(x, dim=axis)
550550
if keepdims:
551-
if axis is not None:
551+
if isinstance(axis, int):
552552
return result.unsqueeze(axis)
553+
elif isinstance(axis, tuple):
554+
n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis]
555+
sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)]
556+
return torch.reshape(result, sh)
553557
return _axis_none_keepdims(result, x.ndim, keepdims)
554558
else:
555559
return result

numpy-1-22-xfails.txt

+7
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc
127127
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
128128
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
129129
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract]
130+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
131+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
132+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply]
133+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
134+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
135+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
136+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
130137

131138
array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
132139

0 commit comments

Comments
 (0)