Skip to content

Commit 55dd99a

Browse files
authored
tensorflow Add tensorflow.summary module (#11358)
Partially derived from https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/summary.pyi
1 parent 9b260a5 commit 55dd99a

File tree

4 files changed

+82
-1
lines changed

4 files changed

+82
-1
lines changed

stubs/tensorflow/@tests/stubtest_allowlist.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ tensorflow.io.SparseFeature.__new__
9191

9292
# Metaclass inconsistency. The runtime metaclass is defined from c++ extension and is undocumented.
9393
tensorflow.io.TFRecordWriter
94+
tensorflow.experimental.dtensor.Mesh
9495

9596
# stubtest does not pass for protobuf generated stubs.
9697
tensorflow.train.Example.*

stubs/tensorflow/tensorflow/_aliases.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ KerasSerializable: TypeAlias = KerasSerializable1 | KerasSerializable2
2929

3030
Slice: TypeAlias = int | slice | None
3131
FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence]
32+
IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence]
3233
StrDataSequence: TypeAlias = Sequence[str] | Sequence[StrDataSequence]
3334
ScalarTensorCompatible: TypeAlias = tf.Tensor | str | float | np.ndarray[Any, Any] | np.number[Any]
3435

@@ -52,4 +53,6 @@ ContainerInputSpec: TypeAlias = ContainerGeneric[InputSpec]
5253

5354
AnyArray: TypeAlias = npt.NDArray[Any]
5455
FloatArray: TypeAlias = npt.NDArray[np.float_ | np.float16 | np.float32 | np.float64]
55-
IntArray: TypeAlias = npt.NDArray[np.int_ | np.uint8 | np.int32 | np.int64]
56+
UIntArray: TypeAlias = npt.NDArray[np.uint | np.uint8 | np.uint16 | np.uint32 | np.uint64]
57+
SignedIntArray: TypeAlias = npt.NDArray[np.int_ | np.int8 | np.int16 | np.int32 | np.int64]
58+
IntArray: TypeAlias = UIntArray | SignedIntArray
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from _typeshed import Incomplete
2+
3+
from tensorflow._aliases import IntArray, IntDataSequence
4+
5+
class Mesh:
6+
def __init__(
7+
self,
8+
dim_names: list[str],
9+
global_device_ids: IntArray | IntDataSequence,
10+
local_device_ids: list[int],
11+
local_devices: list[Incomplete | str],
12+
mesh_name: str = "",
13+
global_devices: list[Incomplete | str] | None = None,
14+
use_xla_spmd: bool = False,
15+
) -> None: ...
16+
17+
def __getattr__(name: str) -> Incomplete: ...
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import abc
2+
from _typeshed import Incomplete
3+
from collections.abc import Callable, Iterator
4+
from contextlib import AbstractContextManager, contextmanager
5+
from typing import Literal
6+
from typing_extensions import Self
7+
8+
import tensorflow as tf
9+
from tensorflow._aliases import FloatArray, IntArray
10+
from tensorflow.experimental.dtensor import Mesh
11+
12+
class SummaryWriter(metaclass=abc.ABCMeta):
13+
def as_default(self, step: int | None = None) -> AbstractContextManager[Self]: ...
14+
def close(self) -> None: ...
15+
def flush(self) -> None: ...
16+
def init(self) -> None: ...
17+
def set_as_default(self, step: int | None = None) -> None: ...
18+
19+
def audio(
20+
name: str,
21+
data: tf.Tensor,
22+
sample_rate: int | tf.Tensor,
23+
step: int | tf.Tensor | None = None,
24+
max_outputs: int | tf.Tensor | None = 3,
25+
encoding: Literal["wav"] | None = None,
26+
description: str | None = None,
27+
) -> bool: ...
28+
def create_file_writer(
29+
logdir: str,
30+
max_queue: int | None = None,
31+
flush_millis: int | None = None,
32+
filename_suffix: str | None = None,
33+
name: str | None = None,
34+
experimental_trackable: bool = False,
35+
experimental_mesh: Mesh | None = None,
36+
) -> SummaryWriter: ...
37+
def create_noop_writer() -> SummaryWriter: ...
38+
def flush(writer: SummaryWriter | None = None, name: str | None = None) -> tf.Operation: ...
39+
def graph(graph_data: tf.Graph | tf.compat.v1.GraphDef) -> bool: ...
40+
def histogram(
41+
name: str, data: tf.Tensor, step: int | None = None, buckets: int | None = None, description: str | None = None
42+
) -> bool: ...
43+
def image(
44+
name: str,
45+
data: tf.Tensor | FloatArray | IntArray,
46+
step: int | tf.Tensor | None = None,
47+
max_outputs: int | None = 3,
48+
description: str | None = None,
49+
) -> bool: ...
50+
@contextmanager
51+
def record_if(condition: bool | tf.Tensor | Callable[[], bool]) -> Iterator[None]: ...
52+
def scalar(name: str, data: float | tf.Tensor, step: int | tf.Tensor | None = None, description: str | None = None) -> bool: ...
53+
def should_record_summaries() -> bool: ...
54+
def text(name: str, data: str | tf.Tensor, step: int | tf.Tensor | None = None, description: str | None = None) -> bool: ...
55+
def trace_export(name: str, step: int | tf.Tensor | None = None, profiler_outdir: str | None = None) -> None: ...
56+
def trace_off() -> None: ...
57+
def trace_on(graph: bool = True, profiler: bool = False) -> None: ...
58+
def write(
59+
tag: str, tensor: tf.Tensor, step: int | tf.Tensor | None = None, metadata: Incomplete | None = None, name: str | None = None
60+
) -> bool: ...

0 commit comments

Comments
 (0)