Skip to content

Commit a1c185b

Browse files
authored
tensorflow: add a few TensorFlow functions (#13364)
1 parent 84c78c6 commit a1c185b

File tree

3 files changed

+43
-2
lines changed

3 files changed

+43
-2
lines changed

stubs/tensorflow/tensorflow/__init__.pyi

+31-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
66
from contextlib import contextmanager
77
from enum import Enum
88
from types import TracebackType
9-
from typing import Any, Generic, NoReturn, TypeVar, overload
9+
from typing import Any, Generic, Literal, NoReturn, TypeVar, overload
1010
from typing_extensions import ParamSpec, Self
1111

1212
from google.protobuf.message import Message
@@ -20,7 +20,17 @@ from tensorflow import (
2020
math as math,
2121
types as types,
2222
)
23-
from tensorflow._aliases import AnyArray, DTypeLike, ShapeLike, Slice, TensorCompatible
23+
from tensorflow._aliases import (
24+
AnyArray,
25+
DTypeLike,
26+
IntArray,
27+
ScalarTensorCompatible,
28+
ShapeLike,
29+
Slice,
30+
SparseTensorCompatible,
31+
TensorCompatible,
32+
UIntTensorCompatible,
33+
)
2434
from tensorflow.autodiff import GradientTape as GradientTape
2535
from tensorflow.core.protobuf import struct_pb2
2636
from tensorflow.dtypes import *
@@ -56,6 +66,7 @@ from tensorflow.math import (
5666
reduce_min as reduce_min,
5767
reduce_prod as reduce_prod,
5868
reduce_sum as reduce_sum,
69+
round as round,
5970
sigmoid as sigmoid,
6071
sign as sign,
6172
sin as sin,
@@ -403,4 +414,22 @@ def ones_like(
403414
input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
404415
) -> RaggedTensor: ...
405416
def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ...
417+
def pad(
418+
tensor: TensorCompatible,
419+
paddings: Tensor | IntArray | Iterable[Iterable[int]],
420+
mode: Literal["CONSTANT", "constant", "REFLECT", "reflect", "SYMMETRIC", "symmectric"] = "CONSTANT",
421+
constant_values: ScalarTensorCompatible = 0,
422+
name: str | None = None,
423+
) -> Tensor: ...
424+
def shape(input: SparseTensorCompatible, out_type: DTypeLike | None = None, name: str | None = None) -> Tensor: ...
425+
def where(
426+
condition: TensorCompatible, x: TensorCompatible | None = None, y: TensorCompatible | None = None, name: str | None = None
427+
) -> Tensor: ...
428+
def gather_nd(
429+
params: TensorCompatible,
430+
indices: UIntTensorCompatible,
431+
batch_dims: UIntTensorCompatible = 0,
432+
name: str | None = None,
433+
bad_indices_policy: Literal["", "DEFAULT", "ERROR", "IGNORE"] = "",
434+
) -> Tensor: ...
406435
def __getattr__(name: str) -> Incomplete: ...

stubs/tensorflow/tensorflow/math.pyi

+6
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,12 @@ def square(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
219219
def softplus(features: TensorCompatible, name: str | None = None) -> Tensor: ...
220220
@overload
221221
def softplus(features: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
222+
@overload
223+
def round(x: TensorCompatible, name: str | None = None) -> Tensor: ...
224+
@overload
225+
def round(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
226+
@overload
227+
def round(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
222228

223229
# Depending on the method axis is either a rank 0 tensor or a rank 0/1 tensor.
224230
def reduce_mean(
+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from tensorflow import Tensor
2+
from tensorflow._aliases import DTypeLike, TensorCompatible
3+
4+
def hamming_window(
5+
window_length: TensorCompatible, periodic: bool | TensorCompatible = True, dtype: DTypeLike = ..., name: str | None = None
6+
) -> Tensor: ...

0 commit comments

Comments
 (0)