 from functools import reduce as _reduce, wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union

 import torch

 from .._internal import get_xp
 from ..common import _aliases
+from ..common._typing import NestedSequence, SupportsBufferProtocol
 from ._info import __array_namespace_info__
 from ._typing import Array, Device, DType

@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
 remainder = _two_arg(torch.remainder)
 subtract = _two_arg(torch.subtract)

+
+def asarray(
+    obj: (
+        Array
+        | bool | int | float | complex
+        | NestedSequence[bool | int | float | complex]
+        | SupportsBufferProtocol
+    ),
+    /,
+    *,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    copy: bool | None = None,
+    **kwargs: Any,
+) -> Array:
+    # torch.asarray does not respect input->output device propagation
+    # https://github.com/pytorch/pytorch/issues/150199
+    if device is None and isinstance(obj, torch.Tensor):
+        device = obj.device
+    return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
+
+
 # These wrappers are mostly based on the fact that pytorch uses 'dim' instead
 # of 'axis'.

@@ -227,6 +250,9 @@ def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
 unstack = get_xp(torch)(_aliases.unstack)
 cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
 cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
+finfo = get_xp(torch)(_aliases.finfo)
+iinfo = get_xp(torch)(_aliases.iinfo)
+

 # torch.sort also returns a tuple
 # https://github.com/pytorch/pytorch/issues/70921
@@ -282,7 +308,6 @@ def prod(x: Array,
          dtype: Optional[DType] = None,
          keepdims: bool = False,
          **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim

     # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -318,7 +343,6 @@ def sum(x: Array,
         dtype: Optional[DType] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim

     # https://github.com/pytorch/pytorch/issues/29137.
@@ -348,7 +372,6 @@ def any(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
         return x.to(torch.bool)
@@ -373,7 +396,6 @@ def all(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
         return x.to(torch.bool)
@@ -816,7 +838,7 @@ def sign(x: Array, /) -> Array:
     return out


-__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
+__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
            'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
            'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
            'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
@@ -832,6 +854,6 @@ def sign(x: Array, /) -> Array:
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
            'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
-           'take', 'take_along_axis', 'sign']
+           'take', 'take_along_axis', 'sign', 'finfo', 'iinfo']

 _all_ignore = ['torch', 'get_xp']
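
A minimal usage sketch of the new asarray wrapper follows (not part of the diff). It assumes this change is available through the array_api_compat.torch namespace and that torch is installed; the device names in the comments are illustrative.

# Illustrative only: shows the intent of the asarray wrapper added above,
# assuming it is exposed as array_api_compat.torch.asarray.
import torch
import array_api_compat.torch as xp

t = torch.ones(3)  # with a GPU build this tensor could live on e.g. "cuda:0"

# Plain torch.asarray(t) with device=None may not propagate t's device
# (https://github.com/pytorch/pytorch/issues/150199); the wrapper copies
# obj.device over before delegating, so the result stays on t's device.
out = xp.asarray(t)
assert out.device == t.device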