Skip to content

Commit 1fa2ed8

Browse files
zoyahavtfx-copybara
authored andcommitted
Address and remove remaining TF 1.15 support related TODOs.
PiperOrigin-RevId: 493487255
1 parent 3147adf commit 1fa2ed8

File tree

8 files changed

+14
-47
lines changed

8 files changed

+14
-47
lines changed

tensorflow_transform/experimental/mappers.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -325,14 +325,9 @@ def _to_term_document_one_hot(
325325
# frequency only cares the existence of a term in a document, not the
326326
# occurrence frequency within that document.
327327
# Hashing (<batch_index>, <vocab_index>) pairs for dedup.
328-
# TODO(b/160294509): Switch to tf.raw_ops.UniqueV2 to avoid hashing for tf1
329-
# tf.raw_ops.UniqueV2 always results in rank-1 tensor placeholder in tf1, even
330-
# when the input is 2D. This causes issues when applying UniqueV2 to
331-
# (<batch_index>, <vocab_index>) 2D tensor along axis=0 and using the uniqued
332-
# 2D tensor as indices for sparse tensors in tf1. See b/160294509#comment8.
333328
multiplier = vocab_size + 1
334-
unique_flatten_indices, _ = tf.unique(batch_indices * multiplier +
335-
vocab_indices)
329+
unique_flatten_indices, _ = tf.raw_ops.UniqueV2(
330+
x=batch_indices * multiplier + vocab_indices, axis=[0])
336331
unique_batch_indices = tf.cast(
337332
tf.math.divide(unique_flatten_indices, multiplier), dtype=tf.int64)
338333
unique_vocab_indices = tf.math.mod(unique_flatten_indices, multiplier)

tensorflow_transform/graph_tools_test.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,6 @@ class GraphToolsTestUniquePath(test_case.TransformTestCase):
10421042
}),
10431043
dict(
10441044
testcase_name='_y_function_of_x_with_raw_ops_while',
1045-
skip_test_check_fn=test_case.skip_if_external_environment,
10461045
create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while,
10471046
feeds=['x'],
10481047
replaced_tensors_ready={'x': False},
@@ -1092,7 +1091,6 @@ class GraphToolsTestUniquePath(test_case.TransformTestCase):
10921091
}),
10931092
dict(
10941093
testcase_name='_y_function_of_x_with_tf_while',
1095-
skip_test_check_fn=test_case.skip_if_external_environment,
10961094
create_graph_fn=_create_graph_with_tf_function_while,
10971095
feeds=['x'],
10981096
replaced_tensors_ready={'x': False},
@@ -1226,13 +1224,7 @@ def testGetUniquePath(self,
12261224
create_graph_fn,
12271225
feeds,
12281226
replaced_tensors_ready,
1229-
expected_calls_dict,
1230-
skip_test_check_fn=None):
1231-
1232-
# TODO(b/160294509): Remove this condition when TFT no longer supports TF<2.
1233-
if skip_test_check_fn:
1234-
skip_test_check_fn('This test is not currently supported.')
1235-
1227+
expected_calls_dict):
12361228
with tf.compat.v1.Graph().as_default() as graph:
12371229
tensors = create_graph_fn()
12381230
replaced_tensors_ready = [(tensors[name], ready)

tensorflow_transform/output_wrapper.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,8 +443,7 @@ def __init__(self,
443443
self._exported_as_v1 = exported_as_v1
444444
self._saved_model_loader_value = None
445445
self._loaded_saved_model_graph = None
446-
# TODO(b/160294509): Use tf.compat.v1 when we stop supporting TF 1.15.
447-
if ops.executing_eagerly_outside_functions():
446+
if tf.compat.v1.executing_eagerly_outside_functions():
448447
# The model must be tracked by assigning to an attribute of the Keras
449448
# layer. Hence, we track the attributes of _saved_model_loader here as
450449
# well.
@@ -470,8 +469,7 @@ def _saved_model_loader(self) -> saved_transform_io_v2.SavedModelLoader:
470469
self._tft_output.transform_savedmodel_dir)
471470
self._loaded_saved_model_graph = ops.get_default_graph()
472471

473-
# TODO(b/160294509): Use tf.compat.v1 when we stop supporting TF 1.15.
474-
if ops.executing_eagerly_outside_functions():
472+
if tf.compat.v1.executing_eagerly_outside_functions():
475473
return self._saved_model_loader_value
476474
else:
477475
assert not self._exported_as_v1

tensorflow_transform/saved/saved_transform_io_v2.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,7 @@ def __init__(self, saved_model_dir: str):
128128
defined in `../constants.py` ('transform' and 'transform_signature',
129129
respectively).
130130
"""
131-
# TODO(b/160294509): Stop using tf.compat.v2 when TF1.15 support is
132-
# dropped.
133-
imported = tf.compat.v2.saved_model.load(saved_model_dir)
131+
imported = tf.saved_model.load(saved_model_dir)
134132
load_v2_in_compat = constants.TRANSFORM_SIGNATURE in imported.signatures
135133
if load_v2_in_compat:
136134
restored_function = imported.signatures[constants.TRANSFORM_SIGNATURE]

tensorflow_transform/test_case.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,8 @@ def _make_placeholder(tensor_spec):
112112
return tf.compat.v1.sparse_placeholder(
113113
shape=tensor_spec.shape, dtype=tensor_spec.dtype)
114114
if isinstance(tensor_spec, tf.RaggedTensorSpec):
115-
# TODO(b/160294509): Switch to public APIs once TF 1 support is dropped.
116115
return tf.compat.v1.ragged.placeholder(
117-
tensor_spec._dtype, tensor_spec._ragged_rank, value_shape=()) # pylint: disable=protected-access
116+
tensor_spec.dtype, tensor_spec.ragged_rank, value_shape=())
118117
else:
119118
return tf.compat.v1.placeholder(
120119
shape=tensor_spec.shape, dtype=tensor_spec.dtype)
@@ -164,8 +163,7 @@ def _wrap_as_constant(value, tensor_spec):
164163
values=tf.constant(value.values, dtype=tensor_spec.dtype),
165164
dense_shape=tf.constant(value.dense_shape, dtype=tf.int64))
166165
elif isinstance(tensor_spec, tf.RaggedTensorSpec):
167-
# TODO(b/160294509): Switch to public APIs once TF 1 support is dropped.
168-
result = _ragged_value_as_constant(value, tensor_spec._dtype) # pylint: disable=protected-access
166+
result = _ragged_value_as_constant(value, tensor_spec.dtype)
169167
else:
170168
result = tf.constant(value, dtype=tensor_spec.dtype)
171169
result.shape.assert_is_compatible_with(tensor_spec.shape)
@@ -299,10 +297,10 @@ def _assertValuesCloseOrEqual(self, a_value, b_value, msg=None):
299297
if (isinstance(a_value, (bytes, str)) or isinstance(a_value, list) and
300298
a_value and isinstance(a_value[0], (bytes, str)) or
301299
isinstance(a_value, np.ndarray) and a_value.dtype == object):
302-
self.assertAllEqual(a_value, b_value)
300+
self.assertAllEqual(a_value, b_value, msg=msg)
303301
else:
304302
# TODO(varshaan): Change atol only for tests for which 1e-6 is too strict.
305-
self.assertAllClose(a_value, b_value, atol=1e-5)
303+
self.assertAllClose(a_value, b_value, atol=1e-5, msg=msg)
306304

307305
def AssertVocabularyContents(self, vocab_file_path, file_contents):
308306
if vocab_file_path.endswith('.tfrecord.gz'):

tensorflow_transform/tf2_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,18 @@
2121
from tensorflow_transform import common_types
2222
# pylint: disable=g-direct-tensorflow-import
2323
from tensorflow.python import tf2
24-
from tensorflow.python.framework import ops
2524
from tensorflow.python.framework.func_graph import FuncGraph
2625
# pylint: enable=g-direct-tensorflow-import
2726

2827

2928
def use_tf_compat_v1(force_tf_compat_v1: bool) -> bool:
3029
"""Evaluate from environment variables if TF should be used in compat.v1 mode."""
3130
major, _, _ = tf.version.VERSION.split('.')
32-
# TODO(b/160294509): Use tf.compat.v1 when we stop supporting TF 1.15.
3331
# If tf.enable_v2_behavior has been called, but eager execution has been
3432
# disabled, force compat v1 behavior. Hence, check
3533
# `executing_eagerly_outside_functions` as well.
3634
return (force_tf_compat_v1 or int(major) < 2 or not tf2.enabled() or
37-
not ops.executing_eagerly_outside_functions())
35+
not tf.compat.v1.executing_eagerly_outside_functions())
3836

3937

4038
def strip_and_get_tensors_and_control_dependencies(

tensorflow_transform/tf_utils.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -323,13 +323,11 @@ def hashable_tensor_or_op(tensor_or_op):
323323
if isinstance(tensor_or_op, tf.Tensor):
324324
return tensor_or_op.experimental_ref()
325325
if isinstance(tensor_or_op, composite_tensor.CompositeTensor):
326-
# TODO(b/160294509): Use tf.type_spec_from_value - only available in TF 2.
327326
return _CompositeTensorRef(
328-
type_spec=tensor_or_op._type_spec, # pylint: disable=protected-access
327+
type_spec=tf.type_spec_from_value(tensor_or_op),
329328
list_of_refs=tuple(
330329
hashable_tensor_or_op(component) for component in tf.nest.flatten(
331-
tensor_or_op, expand_composites=True)
332-
))
330+
tensor_or_op, expand_composites=True)))
333331
return tensor_or_op
334332

335333

@@ -432,10 +430,7 @@ def reduce_batch_count(x: common_types.TensorType,
432430
dense_shape=x.dense_shape)
433431
# TODO(b/178189903): Remove this once we no longer lose static shape
434432
# information.
435-
# TODO(b/160294509): Remove the hasattr contition once TFT no longer
436-
# supports TF<2.
437-
if hasattr(x, '_dense_shape_default'):
438-
ones_like._dense_shape_default = x._dense_shape_default # pylint: disable=protected-access
433+
ones_like._dense_shape_default = x._dense_shape_default # pylint: disable=protected-access
439434
return _sparse_reduce_batch_keep_shape(tf.sparse.reduce_sum, ones_like)
440435
elif isinstance(x, tf.RaggedTensor):
441436
if reduce_instance_dims:
@@ -697,10 +692,6 @@ def _split_vocabulary_entries(batched_vocab_lines):
697692
if isinstance(split, tf.RaggedTensor):
698693
split_tensor = split.to_tensor()
699694
return split_tensor[:, 1], split_tensor[:, 0]
700-
# TODO(b/160294509): Remove this condition when TFT no longer supports TF<2.
701-
elif isinstance(split, tf.SparseTensor):
702-
split_tensor = tf.sparse.to_dense(split)
703-
return split_tensor[:, 1], split_tensor[:, 0]
704695
else:
705696
return split[1], split[0]
706697

tensorflow_transform/tf_utils_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2404,7 +2404,4 @@ def foo(input_tensor):
24042404

24052405

24062406
if __name__ == '__main__':
2407-
# TODO(b/160294509): Remove this once this is enabled by default in all
2408-
# supported TF versions.
2409-
tf.compat.v1.enable_v2_tensorshape()
24102407
test_case.main()

0 commit comments

Comments
 (0)