Skip to content

Commit 3147adf

Browse files
zoyahavtfx-copybara
authored and committed
Remove is_vocabulary_tfrecord_supported and remove instances of raising an error when vocabulary is called with 'tfrecord_gzip' format when it's not supported.
PiperOrigin-RevId: 493447110
1 parent 520c1de commit 3147adf

File tree

6 files changed

+0
-60
lines changed

6 files changed

+0
-60
lines changed

tensorflow_transform/analyzers.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,11 +2010,6 @@ def _vocabulary_analyzer_nodes(
20102010
vocabulary_key: Optional[str] = None
20112011
) -> common_types.TemporaryAnalyzerOutputType:
20122012
"""Internal helper for analyzing vocab. See `vocabulary` doc string."""
2013-
if (file_format == 'tfrecord_gzip' and
2014-
not tf_utils.is_vocabulary_tfrecord_supported()):
2015-
raise ValueError(
2016-
'Vocabulary file_format "tfrecord_gzip" not yet supported for '
2017-
f'{tf.version.VERSION}.')
20182013

20192014
input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
20202015
analyzer_inputs)

tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,40 +14,13 @@
1414
# limitations under the License.
1515
"""Tests for tfrecord_gzip tft.vocabulary and tft.compute_and_apply_vocabulary."""
1616

17-
import tensorflow as tf
18-
from tensorflow_transform import tf2_utils
19-
from tensorflow_transform import tf_utils
2017
from tensorflow_transform.beam import tft_unit
2118
from tensorflow_transform.beam import vocabulary_integration_test
2219

23-
import unittest
24-
25-
mock = tf.compat.v1.test.mock
26-
2720

2821
class TFRecordVocabularyIntegrationTest(
2922
vocabulary_integration_test.VocabularyIntegrationTest):
3023

31-
def setUp(self):
32-
# TODO(b/164921571): Remove mock once tfrecord vocabularies are supported in
33-
# all TF versions.
34-
if not tf2_utils.use_tf_compat_v1(force_tf_compat_v1=False):
35-
self.is_vocabulary_tfrecord_supported_patch = mock.patch(
36-
'tensorflow_transform.tf_utils.is_vocabulary_tfrecord_supported')
37-
mock_is_vocabulary_tfrecord_supported = (
38-
self.is_vocabulary_tfrecord_supported_patch.start())
39-
mock_is_vocabulary_tfrecord_supported.side_effect = lambda: True
40-
41-
if (tft_unit.is_external_environment() and
42-
not tf_utils.is_vocabulary_tfrecord_supported()):
43-
raise unittest.SkipTest('Test requires async DatasetInitializer')
44-
super().setUp()
45-
46-
def tearDown(self):
47-
if not tf2_utils.use_tf_compat_v1(force_tf_compat_v1=False):
48-
self.is_vocabulary_tfrecord_supported_patch.stop()
49-
super().tearDown()
50-
5124
def _VocabFormat(self):
5225
return 'tfrecord_gzip'
5326

tensorflow_transform/experimental/analyzers.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -431,11 +431,6 @@ def _approximate_vocabulary_analyzer_nodes(
431431
file_format: common_types.VocabularyFileFormatType,
432432
vocabulary_key: str) -> common_types.TemporaryAnalyzerOutputType:
433433
"""Internal helper for analyzing vocab. See `vocabulary` doc string."""
434-
if (file_format == 'tfrecord_gzip' and
435-
not tf_utils.is_vocabulary_tfrecord_supported()):
436-
raise ValueError(
437-
'Vocabulary file_format "tfrecord_gzip" requires TF version >= 2.4')
438-
439434
# TODO(b/208879020): Add vocabulary size annotation for this analyzer.
440435
analyzers.register_vocab(
441436
vocab_filename, vocabulary_key=vocabulary_key, file_format=file_format)

tensorflow_transform/mappers.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,11 +1088,6 @@ def apply_vocabulary(
10881088
starting from zero, and string value not in the vocabulary is
10891089
assigned default_value.
10901090
"""
1091-
if (file_format == 'tfrecord_gzip' and
1092-
not tf_utils.is_vocabulary_tfrecord_supported()):
1093-
raise ValueError(
1094-
'Vocabulary file_format "tfrecord_gzip" not yet supported for '
1095-
f'{tf.version.VERSION}.')
10961091
with tf.compat.v1.name_scope(name, 'apply_vocab'):
10971092
if x.dtype != tf.string and not x.dtype.is_integer:
10981093
raise ValueError('expected tf.string or tf.int[8|16|32|64] but got %r' %

tensorflow_transform/tf_utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import enum
1818
from typing import Callable, Optional, Tuple, Union
1919

20-
from packaging import version
2120
import tensorflow as tf
2221
from tensorflow_transform import annotators
2322
from tensorflow_transform import common_types
@@ -536,16 +535,6 @@ def reorder_histogram(bucket_vocab: tf.Tensor, counts: tf.Tensor,
536535
return tf.gather(counts, ordering)
537536

538537

539-
# TODO(b/62379925): Remove this once all supported TF versions have
540-
# tf.data.experimental.DatasetInitializer.
541-
def is_vocabulary_tfrecord_supported() -> bool:
542-
if isinstance(ops.get_default_graph(), func_graph.FuncGraph):
543-
return False
544-
return ((hasattr(tf.data.experimental, 'DatasetInitializer') or
545-
hasattr(tf.lookup.experimental, 'DatasetInitializer')) and
546-
version.parse(tf.version.VERSION) >= version.parse('2.4'))
547-
548-
549538
# Used to decide which bucket boundary index to assign to a value.
550539
class Side(enum.Enum):
551540
RIGHT = 'right'

tensorflow_transform/tf_utils_test.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from tensorflow_transform import tf_utils
2424
from tensorflow_transform import test_case
2525

26-
import unittest
2726
from tensorflow.python.framework import composite_tensor # pylint: disable=g-direct-tensorflow-import
2827

2928
_CONSTRUCT_TABLE_PARAMETERS = [
@@ -2273,12 +2272,6 @@ def _to_idf(df, corpus_size):
22732272

22742273
class VocabTFUtilsTest(test_case.TransformTestCase):
22752274

2276-
def setUp(self):
2277-
if (not tf_utils.is_vocabulary_tfrecord_supported() and
2278-
test_case.is_external_environment()):
2279-
raise unittest.SkipTest('Test requires DatasetInitializer')
2280-
super().setUp()
2281-
22822275
def _write_tfrecords(self, path, bytes_records):
22832276
with tf.io.TFRecordWriter(path, 'GZIP') as writer:
22842277
for record in bytes_records:

0 commit comments

Comments (0)