Skip to content

Commit 2b5e289

Browse files
authored
Merge pull request #322 from ev-br/torch_meshgrid
BUG: torch: meshgrid defaults to indexing="xy"
2 parents 0a14d6c + 5cf5d8f commit 2b5e289

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

array_api_compat/torch/_aliases.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from functools import reduce as _reduce, wraps as _wraps
44
from builtins import all as _builtin_all, any as _builtin_any
5-
from typing import Any, List, Optional, Sequence, Tuple, Union
5+
from typing import Any, List, Optional, Sequence, Tuple, Union, Literal
66

77
import torch
88

@@ -828,6 +828,12 @@ def sign(x: Array, /) -> Array:
828828
return out
829829

830830

831+
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]:
832+
# enforce the default of 'xy'
833+
# TODO: is the return type a list or a tuple
834+
return list(torch.meshgrid(*arrays, indexing='xy'))
835+
836+
831837
__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
832838
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
833839
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
@@ -844,6 +850,6 @@ def sign(x: Array, /) -> Array:
844850
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
845851
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
846852
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
847-
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat']
853+
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid']
848854

849855
_all_ignore = ['torch', 'get_xp']

tests/test_torch.py

+17
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,20 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b):
100100
assert dtype_1 == dtype_2
101101
finally:
102102
torch.set_default_dtype(prev_default)
103+
104+
105+
def test_meshgrid():
106+
"""Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""
107+
108+
x, y = xp.asarray([1, 2]), xp.asarray([4])
109+
110+
X, Y = xp.meshgrid(x, y)
111+
112+
# output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different
113+
X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]])
114+
115+
assert X.shape == X_xy.shape
116+
assert xp.all(X == X_xy)
117+
118+
assert Y.shape == Y_xy.shape
119+
assert xp.all(Y == Y_xy)

0 commit comments

Comments
 (0)