|
24 | 24 | from __future__ import division
|
25 | 25 | from __future__ import print_function
|
26 | 26 |
|
| 27 | +from typing import Any, Callable, Dict, List, Tuple |
| 28 | + |
27 | 29 | import numpy as np
|
28 | 30 | import tensorflow as tf
|
29 | 31 |
|
30 | 32 | from tensorflow_graphics.geometry.convolution import utils as conv_utils
|
31 | 33 | from tensorflow_graphics.geometry.representation.mesh import utils as mesh_utils
|
32 | 34 | from tensorflow_graphics.util import shape
|
| 35 | +from tensorflow_graphics.util import type_alias |
33 | 36 |
|
34 | 37 | DEFAULT_IO_PARAMS = {
|
35 | 38 | 'batch_size': 8,
|
|
42 | 45 | }
|
43 | 46 |
|
44 | 47 |
|
45 |
| -def adjacency_from_edges(edges, weights, num_edges, num_vertices): |
| 48 | +def adjacency_from_edges( |
| 49 | + edges: type_alias.TensorLike, |
| 50 | + weights: type_alias.TensorLike, |
| 51 | + num_edges: type_alias.TensorLike, |
| 52 | + num_vertices: type_alias.TensorLike) -> tf.SparseTensor: |
46 | 53 | """Returns a batched sparse 1-ring adj tensor from edge list tensor.
|
47 | 54 |
|
48 | 55 | Args:
|
@@ -103,7 +110,9 @@ def adjacency_from_edges(edges, weights, num_edges, num_vertices):
|
103 | 110 | return adjacency
|
104 | 111 |
|
105 | 112 |
|
106 |
| -def get_weighted_edges(faces, self_edges=True): |
| 113 | +def get_weighted_edges( |
| 114 | + faces: np.ndarray, |
| 115 | + self_edges: bool = True) -> Tuple[np.ndarray, np.ndarray]: |
107 | 116 | r"""Gets unique edges and degree weights from a triangular mesh.
|
108 | 117 |
|
109 | 118 | The shorthands used below are:
|
@@ -136,12 +145,12 @@ def get_weighted_edges(faces, self_edges=True):
|
136 | 145 | return edges, weights
|
137 | 146 |
|
138 | 147 |
|
139 |
| -def _tfrecords_to_dataset(tfrecords, |
140 |
| - parallel_threads, |
141 |
| - shuffle, |
142 |
| - repeat, |
143 |
| - sloppy, |
144 |
| - max_readers=16): |
| 148 | +def _tfrecords_to_dataset(tfrecords: List[str], |
| 149 | + parallel_threads: int, |
| 150 | + shuffle: bool, |
| 151 | + repeat: bool, |
| 152 | + sloppy: bool, |
| 153 | + max_readers: int = 16) -> tf.data.TFRecordDataset: |
145 | 154 | """Creates a TFRecordsDataset that iterates over filenames in parallel.
|
146 | 155 |
|
147 | 156 | Args:
|
@@ -244,7 +253,9 @@ def _parse_mesh_data(mesh_data, mean_center=True):
|
244 | 253 | return mesh_data
|
245 | 254 |
|
246 | 255 |
|
247 |
| -def create_dataset_from_tfrecords(tfrecords, params): |
| 256 | +def create_dataset_from_tfrecords( |
| 257 | + tfrecords: List[str], |
| 258 | + params: Dict[str, Any]) -> tf.data.Dataset: |
248 | 259 | """Creates a mesh dataset given a list of tf records filenames.
|
249 | 260 |
|
250 | 261 | Args:
|
@@ -309,7 +320,10 @@ def _set_default_if_none(param, param_dict, default_val):
|
309 | 320 | drop_remainder=is_training)
|
310 | 321 |
|
311 | 322 |
|
312 |
| -def create_input_from_dataset(dataset_fn, files, io_params): |
| 323 | +def create_input_from_dataset( |
| 324 | + dataset_fn: Callable[..., Any], |
| 325 | + files: List[str], |
| 326 | + io_params: Dict[str, Any]) -> Tuple[Dict[str, Any], tf.Tensor]: |
313 | 327 | """Creates input function given dataset generator and input files.
|
314 | 328 |
|
315 | 329 | Args:
|
|
0 commit comments