2
2
3
3
from functools import reduce as _reduce , wraps as _wraps
4
4
from builtins import all as _builtin_all , any as _builtin_any
5
- from typing import Any , List , Optional , Sequence , Tuple , Union
5
+ from typing import Any , List , Optional , Sequence , Tuple , Union , Literal
6
6
7
7
import torch
8
8
@@ -828,6 +828,12 @@ def sign(x: Array, /) -> Array:
828
828
return out
829
829
830
830
831
+ def meshgrid (* arrays : Array , indexing : Literal ['xy' , 'ij' ] = 'xy' ) -> List [Array ]:
832
+ # enforce the default of 'xy'
833
+ # TODO: is the return type a list or a tuple
834
+ return list (torch .meshgrid (* arrays , indexing = 'xy' ))
835
+
836
+
831
837
__all__ = ['__array_namespace_info__' , 'asarray' , 'result_type' , 'can_cast' ,
832
838
'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
833
839
'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
@@ -844,6 +850,6 @@ def sign(x: Array, /) -> Array:
844
850
'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
845
851
'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
846
852
'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ,
847
- 'take' , 'take_along_axis' , 'sign' , 'finfo' , 'iinfo' , 'repeat' ]
853
+ 'take' , 'take_along_axis' , 'sign' , 'finfo' , 'iinfo' , 'repeat' , 'meshgrid' ]
848
854
849
855
_all_ignore = ['torch' , 'get_xp' ]
0 commit comments