Skip to content

Commit e1976fd

Browse files
G4Gcopybara-github
authored andcommitted
Internal change.
PiperOrigin-RevId: 419892069
1 parent bab2352 commit e1976fd

File tree

8 files changed

+67
-35
lines changed

8 files changed

+67
-35
lines changed

tensorflow_graphics/nn/loss/chamfer_distance.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121

2222
from tensorflow_graphics.util import export_api
2323
from tensorflow_graphics.util import shape
24+
from tensorflow_graphics.util import type_alias
2425

2526

26-
def evaluate(point_set_a, point_set_b, name="chamfer_distance_evaluate"):
27+
def evaluate(point_set_a: type_alias.TensorLike,
28+
point_set_b: type_alias.TensorLike,
29+
name: str = "chamfer_distance_evaluate") -> tf.Tensor:
2730
"""Computes the Chamfer distance for the given two point sets.
2831
2932
Note:

tensorflow_graphics/nn/loss/hausdorff_distance.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121

2222
from tensorflow_graphics.util import export_api
2323
from tensorflow_graphics.util import shape
24+
from tensorflow_graphics.util import type_alias
2425

2526

26-
def evaluate(point_set_a, point_set_b, name="hausdorff_distance_evaluate"):
27+
def evaluate(point_set_a: type_alias.TensorLike,
28+
point_set_b: type_alias.TensorLike,
29+
name: str = "hausdorff_distance_evaluate") -> tf.Tensor:
2730
"""Computes the Hausdorff distance from point_set_a to point_set_b.
2831
2932
Note:

tensorflow_graphics/nn/metric/fscore.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,23 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
from typing import Any, Callable
21+
2022
import tensorflow as tf
2123

2224
from tensorflow_graphics.nn.metric import precision as precision_module
2325
from tensorflow_graphics.nn.metric import recall as recall_module
2426
from tensorflow_graphics.util import export_api
2527
from tensorflow_graphics.util import safe_ops
2628
from tensorflow_graphics.util import shape
29+
from tensorflow_graphics.util import type_alias
2730

2831

29-
def evaluate(ground_truth,
30-
prediction,
31-
precision_function=precision_module.evaluate,
32-
recall_function=recall_module.evaluate,
33-
name="fscore_evaluate"):
32+
def evaluate(ground_truth: type_alias.TensorLike,
33+
prediction: type_alias.TensorLike,
34+
precision_function: Callable[..., Any] = precision_module.evaluate,
35+
recall_function: Callable[..., Any] = recall_module.evaluate,
36+
name: str = "fscore_evaluate") -> tf.Tensor:
3437
"""Computes the fscore metric for the given ground truth and predicted labels.
3538
3639
The fscore is calculated as 2 * (precision * recall) / (precision + recall)

tensorflow_graphics/nn/metric/intersection_over_union.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@
2323
from tensorflow_graphics.util import asserts
2424
from tensorflow_graphics.util import export_api
2525
from tensorflow_graphics.util import shape
26+
from tensorflow_graphics.util import type_alias
2627

2728

28-
def evaluate(ground_truth_labels,
29-
predicted_labels,
30-
grid_size=1,
31-
name="intersection_over_union_evaluate"):
29+
def evaluate(ground_truth_labels: type_alias.TensorLike,
30+
predicted_labels: type_alias.TensorLike,
31+
grid_size: int = 1,
32+
name: str = "intersection_over_union_evaluate") -> tf.Tensor:
3233
"""Computes the Intersection-Over-Union metric for the given ground truth and predicted labels.
3334
3435
Note:

tensorflow_graphics/nn/metric/precision.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,26 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
from typing import Any, Callable, List, Optional, Union, Tuple
21+
2022
import tensorflow as tf
2123

2224
from tensorflow_graphics.util import export_api
2325
from tensorflow_graphics.util import safe_ops
2426
from tensorflow_graphics.util import shape
27+
from tensorflow_graphics.util import type_alias
2528

2629

2730
def _cast_to_int(prediction):
2831
return tf.cast(x=prediction, dtype=tf.int32)
2932

3033

31-
def evaluate(ground_truth,
32-
prediction,
33-
classes=None,
34-
reduce_average=True,
35-
prediction_to_category_function=_cast_to_int,
36-
name="precision_evaluate"):
34+
def evaluate(ground_truth: type_alias.TensorLike,
35+
prediction: type_alias.TensorLike,
36+
classes: Optional[Union[int, List[int], Tuple[int]]] = None,
37+
reduce_average: bool = True,
38+
prediction_to_category_function: Callable[..., Any] = _cast_to_int,
39+
name: str = "precision_evaluate") -> tf.Tensor:
3740
"""Computes the precision metric for the given ground truth and predictions.
3841
3942
Note:

tensorflow_graphics/nn/metric/recall.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,26 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
from typing import Any, Callable, List, Optional, Tuple, Union
21+
2022
import tensorflow as tf
2123

2224
from tensorflow_graphics.util import export_api
2325
from tensorflow_graphics.util import safe_ops
2426
from tensorflow_graphics.util import shape
27+
from tensorflow_graphics.util import type_alias
2528

2629

2730
def _cast_to_int(prediction):
2831
return tf.cast(x=prediction, dtype=tf.int32)
2932

3033

31-
def evaluate(ground_truth,
32-
prediction,
33-
classes=None,
34-
reduce_average=True,
35-
prediction_to_category_function=_cast_to_int,
36-
name="recall_evaluate"):
34+
def evaluate(ground_truth: type_alias.TensorLike,
35+
prediction: type_alias.TensorLike,
36+
classes: Optional[Union[int, List[int], Tuple[int]]] = None,
37+
reduce_average: bool = True,
38+
prediction_to_category_function: Callable[..., Any] = _cast_to_int,
39+
name: str = "recall_evaluate") -> tf.Tensor:
3740
"""Computes the recall metric for the given ground truth and predictions.
3841
3942
Note:

tensorflow_graphics/notebooks/mesh_segmentation_dataio.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,15 @@
2424
from __future__ import division
2525
from __future__ import print_function
2626

27+
from typing import Any, Callable, Dict, List, Tuple
28+
2729
import numpy as np
2830
import tensorflow as tf
2931

3032
from tensorflow_graphics.geometry.convolution import utils as conv_utils
3133
from tensorflow_graphics.geometry.representation.mesh import utils as mesh_utils
3234
from tensorflow_graphics.util import shape
35+
from tensorflow_graphics.util import type_alias
3336

3437
DEFAULT_IO_PARAMS = {
3538
'batch_size': 8,
@@ -42,7 +45,11 @@
4245
}
4346

4447

45-
def adjacency_from_edges(edges, weights, num_edges, num_vertices):
48+
def adjacency_from_edges(
49+
edges: type_alias.TensorLike,
50+
weights: type_alias.TensorLike,
51+
num_edges: type_alias.TensorLike,
52+
num_vertices: type_alias.TensorLike) -> tf.SparseTensor:
4653
"""Returns a batched sparse 1-ring adj tensor from edge list tensor.
4754
4855
Args:
@@ -103,7 +110,9 @@ def adjacency_from_edges(edges, weights, num_edges, num_vertices):
103110
return adjacency
104111

105112

106-
def get_weighted_edges(faces, self_edges=True):
113+
def get_weighted_edges(
114+
faces: np.ndarray,
115+
self_edges: bool = True) -> Tuple[np.ndarray, np.ndarray]:
107116
r"""Gets unique edges and degree weights from a triangular mesh.
108117
109118
The shorthands used below are:
@@ -136,12 +145,12 @@ def get_weighted_edges(faces, self_edges=True):
136145
return edges, weights
137146

138147

139-
def _tfrecords_to_dataset(tfrecords,
140-
parallel_threads,
141-
shuffle,
142-
repeat,
143-
sloppy,
144-
max_readers=16):
148+
def _tfrecords_to_dataset(tfrecords: List[str],
149+
parallel_threads: int,
150+
shuffle: bool,
151+
repeat: bool,
152+
sloppy: bool,
153+
max_readers: int = 16) -> tf.data.TFRecordDataset:
145154
"""Creates a TFRecordsDataset that iterates over filenames in parallel.
146155
147156
Args:
@@ -244,7 +253,9 @@ def _parse_mesh_data(mesh_data, mean_center=True):
244253
return mesh_data
245254

246255

247-
def create_dataset_from_tfrecords(tfrecords, params):
256+
def create_dataset_from_tfrecords(
257+
tfrecords: List[str],
258+
params: Dict[str, Any]) -> tf.data.Dataset:
248259
"""Creates a mesh dataset given a list of tf records filenames.
249260
250261
Args:
@@ -309,7 +320,10 @@ def _set_default_if_none(param, param_dict, default_val):
309320
drop_remainder=is_training)
310321

311322

312-
def create_input_from_dataset(dataset_fn, files, io_params):
323+
def create_input_from_dataset(
324+
dataset_fn: Callable[..., Any],
325+
files: List[str],
326+
io_params: Dict[str, Any]) -> Tuple[Dict[str, Any], tf.Tensor]:
313327
"""Creates input function given dataset generator and input files.
314328
315329
Args:

tensorflow_graphics/notebooks/mesh_viewer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
from typing import Any, Dict
22+
2123
import numpy as np
2224
from tensorflow_graphics.notebooks import threejs_visualization
2325

@@ -32,7 +34,7 @@
3234
class Viewer(object):
3335
"""A ThreeJS based viewer class for viewing 3D meshes."""
3436

35-
def _mesh_from_data(self, data):
37+
def _mesh_from_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
3638
"""Creates a dictionary of ThreeJS mesh objects from numpy data."""
3739
if 'vertices' not in data or 'faces' not in data:
3840
raise ValueError('Mesh Data must contain vertices and faces')
@@ -54,7 +56,7 @@ def _mesh_from_data(self, data):
5456
mesh['material'] = material
5557
return mesh
5658

57-
def __init__(self, source_mesh_data):
59+
def __init__(self, source_mesh_data: Dict[str, Any]):
5860
context = threejs_visualization.build_context()
5961
self.context = context
6062
light1 = context.THREE.PointLight.new_object(0x808080)

0 commit comments

Comments
 (0)