@@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
6
6
from contextlib import contextmanager
7
7
from enum import Enum
8
8
from types import TracebackType
9
- from typing import Any , Generic , NoReturn , TypeVar , overload
9
+ from typing import Any , Generic , Literal , NoReturn , TypeVar , overload
10
10
from typing_extensions import ParamSpec , Self
11
11
12
12
from google .protobuf .message import Message
@@ -20,7 +20,17 @@ from tensorflow import (
20
20
math as math ,
21
21
types as types ,
22
22
)
23
- from tensorflow ._aliases import AnyArray , DTypeLike , ShapeLike , Slice , TensorCompatible
23
+ from tensorflow ._aliases import (
24
+ AnyArray ,
25
+ DTypeLike ,
26
+ IntArray ,
27
+ ScalarTensorCompatible ,
28
+ ShapeLike ,
29
+ Slice ,
30
+ SparseTensorCompatible ,
31
+ TensorCompatible ,
32
+ UIntTensorCompatible ,
33
+ )
24
34
from tensorflow .autodiff import GradientTape as GradientTape
25
35
from tensorflow .core .protobuf import struct_pb2
26
36
from tensorflow .dtypes import *
@@ -56,6 +66,7 @@ from tensorflow.math import (
56
66
reduce_min as reduce_min ,
57
67
reduce_prod as reduce_prod ,
58
68
reduce_sum as reduce_sum ,
69
+ round as round ,
59
70
sigmoid as sigmoid ,
60
71
sign as sign ,
61
72
sin as sin ,
@@ -403,4 +414,22 @@ def ones_like(
403
414
input : RaggedTensor , dtype : DTypeLike | None = None , name : str | None = None , layout : Layout | None = None
404
415
) -> RaggedTensor : ...
405
416
def reshape (tensor : TensorCompatible , shape : ShapeLike | Tensor , name : str | None = None ) -> Tensor : ...
417
+ def pad (
418
+ tensor : TensorCompatible ,
419
+ paddings : Tensor | IntArray | Iterable [Iterable [int ]],
420
+ mode : Literal ["CONSTANT" , "constant" , "REFLECT" , "reflect" , "SYMMETRIC" , "symmectric" ] = "CONSTANT" ,
421
+ constant_values : ScalarTensorCompatible = 0 ,
422
+ name : str | None = None ,
423
+ ) -> Tensor : ...
424
+ def shape (input : SparseTensorCompatible , out_type : DTypeLike | None = None , name : str | None = None ) -> Tensor : ...
425
+ def where (
426
+ condition : TensorCompatible , x : TensorCompatible | None = None , y : TensorCompatible | None = None , name : str | None = None
427
+ ) -> Tensor : ...
428
+ def gather_nd (
429
+ params : TensorCompatible ,
430
+ indices : UIntTensorCompatible ,
431
+ batch_dims : UIntTensorCompatible = 0 ,
432
+ name : str | None = None ,
433
+ bad_indices_policy : Literal ["" , "DEFAULT" , "ERROR" , "IGNORE" ] = "" ,
434
+ ) -> Tensor : ...
406
435
def __getattr__ (name : str ) -> Incomplete : ...
0 commit comments