Commit 520c1de

zoyahavtfx-copybara authored and committed
Remove special handling of ragged features for TF 1.15 now that TFT requires TF > 2.
PiperOrigin-RevId: 493138058
1 parent 1c82cf3 commit 520c1de

13 files changed (+951, -1159 lines)

tensorflow_transform/beam/bucketize_integration_test.py

Lines changed: 127 additions & 133 deletions
@@ -21,7 +21,6 @@
 import tensorflow as tf
 import tensorflow_transform as tft
 from tensorflow_transform import analyzers
-from tensorflow_transform import common_types
 from tensorflow_transform.beam import impl as beam_impl
 from tensorflow_transform.beam import tft_unit
 from tensorflow_metadata.proto.v0 import schema_pb2
@@ -129,7 +128,26 @@ def _compute_simple_per_key_bucket(val, key, weighted=False):
             'x_bucketized$sparse_values': [(x - 1) // 3],
             'x_bucketized$sparse_indices_0': [x % 4],
             'x_bucketized$sparse_indices_1': [x % 5]
-        } for x in range(1, 10)])
+        } for x in range(1, 10)]),
+    dict(
+        testcase_name='ragged',
+        input_data=[{
+            'val': [x, 10 - x],
+            'row_lengths': [0, x % 3, 2 - x % 3],
+        } for x in range(1, 10)],
+        input_metadata=tft.DatasetMetadata.from_feature_spec({
+            'x':
+                tf.io.RaggedFeature(
+                    tf.int64,
+                    value_key='val',
+                    partitions=[
+                        tf.io.RaggedFeature.RowLengths('row_lengths')  # pytype: disable=attribute-error
+                    ]),
+        }),
+        expected_data=[{
+            'x_bucketized$ragged_values': [(x - 1) // 3, (9 - x) // 3],
+            'x_bucketized$row_lengths_1': [0, x % 3, 2 - x % 3],
+        } for x in range(1, 10)]),
 ]

 _BUCKETIZE_PER_KEY_TEST_CASES = [
@@ -211,139 +229,115 @@ def _compute_simple_per_key_bucket(val, key, weighted=False):
                 'x_bucketized':
                     schema_pb2.IntDomain(min=0, max=2, is_categorical=True),
             })),
+    dict(
+        testcase_name='ragged',
+        input_data=[{
+            'val': [x, x],
+            'row_lengths': [x % 3, 2 - (x % 3)],
+            'key_val': ['a', 'a'] if x < 50 else ['b', 'b'],
+            'key_row_lengths': [x % 3, 2 - (x % 3)],
+        } for x in range(1, 100)],
+        input_metadata=tft.DatasetMetadata.from_feature_spec({
+            'x':
+                tf.io.RaggedFeature(
+                    tf.int64,
+                    value_key='val',
+                    partitions=[
+                        tf.io.RaggedFeature.RowLengths('row_lengths')  # pytype: disable=attribute-error
+                    ]),
+            'key':
+                tf.io.RaggedFeature(
+                    tf.string,
+                    value_key='key_val',
+                    partitions=[
+                        tf.io.RaggedFeature.RowLengths('key_row_lengths')  # pytype: disable=attribute-error
+                    ]),
+        }),
+        expected_data=[{
+            'x_bucketized$ragged_values': [
+                _compute_simple_per_key_bucket(x, 'a' if x < 50 else 'b'),
+            ] * 2,
+            'x_bucketized$row_lengths_1': [x % 3, 2 - (x % 3)],
+        } for x in range(1, 100)],
+        expected_metadata=tft.DatasetMetadata.from_feature_spec(
+            {
+                'x_bucketized':
+                    tf.io.RaggedFeature(
+                        tf.int64,
+                        value_key='x_bucketized$ragged_values',
+                        partitions=[
+                            tf.io.RaggedFeature.RowLengths(  # pytype: disable=attribute-error
+                                'x_bucketized$row_lengths_1')
+                        ]),
+            },
+            {
+                'x_bucketized':
+                    schema_pb2.IntDomain(min=0, max=2, is_categorical=True),
+            })),
+    dict(
+        testcase_name='ragged_weighted',
+        input_data=[{
+            'val': [x, x],
+            'row_lengths': [2 - (x % 3), x % 3],
+            'key_val': ['a', 'a'] if x < 50 else ['b', 'b'],
+            'key_row_lengths': [
+                2 - (x % 3),
+                x % 3,
+            ],
+            'weights_val':
+                ([0, 0] if x in _WEIGHTED_PER_KEY_0_RANGE else [1, 1]),
+            'weights_row_lengths': [
+                2 - (x % 3),
+                x % 3,
+            ],
+        } for x in range(1, 100)],
+        input_metadata=tft.DatasetMetadata.from_feature_spec({
+            'x':
+                tf.io.RaggedFeature(
+                    tf.int64,
+                    value_key='val',
+                    partitions=[
+                        tf.io.RaggedFeature.RowLengths('row_lengths')  # pytype: disable=attribute-error
+                    ]),
+            'key':
+                tf.io.RaggedFeature(
+                    tf.string,
+                    value_key='key_val',
+                    partitions=[
+                        tf.io.RaggedFeature.RowLengths('key_row_lengths')  # pytype: disable=attribute-error
+                    ]),
+            'weights':
+                tf.io.RaggedFeature(
+                    tf.int64,
+                    value_key='weights_val',
+                    partitions=[
+                        tf.io.RaggedFeature.RowLengths('weights_row_lengths')  # pytype: disable=attribute-error
+                    ]),
+        }),
+        expected_data=[{
+            'x_bucketized$ragged_values': [
+                _compute_simple_per_key_bucket(
+                    x, 'a' if x < 50 else 'b', weighted=True),
+            ] * 2,
+            'x_bucketized$row_lengths_1': [2 - (x % 3), x % 3],
+        } for x in range(1, 100)],
+        expected_metadata=tft.DatasetMetadata.from_feature_spec(
+            {
+                'x_bucketized':
+                    tf.io.RaggedFeature(
+                        tf.int64,
+                        value_key='x_bucketized$ragged_values',
+                        partitions=[
+                            tf.io.RaggedFeature.RowLengths(  # pytype: disable=attribute-error
+                                'x_bucketized$row_lengths_1')
+                        ]),
+            },
+            {
+                'x_bucketized':
+                    schema_pb2.IntDomain(min=0, max=2, is_categorical=True),
+            })),
 ]

-if common_types.is_ragged_feature_available():
-  _BUCKETIZE_COMPOSITE_INPUT_TEST_CASES.append(
-      dict(
-          testcase_name='ragged',
-          input_data=[{
-              'val': [x, 10 - x],
-              'row_lengths': [0, x % 3, 2 - x % 3],
-          } for x in range(1, 10)],
-          input_metadata=tft.DatasetMetadata.from_feature_spec({
-              'x':
-                  tf.io.RaggedFeature(
-                      tf.int64,
-                      value_key='val',
-                      partitions=[
-                          tf.io.RaggedFeature.RowLengths('row_lengths')  # pytype: disable=attribute-error
-                      ]),
-          }),
-          expected_data=[{
-              'x_bucketized$ragged_values': [(x - 1) // 3, (9 - x) // 3],
-              'x_bucketized$row_lengths_1': [0, x % 3, 2 - x % 3],
-          } for x in range(1, 10)]))
-  _BUCKETIZE_PER_KEY_TEST_CASES.extend([
-      dict(
-          testcase_name='ragged',
-          input_data=[{
-              'val': [x, x],
-              'row_lengths': [x % 3, 2 - (x % 3)],
-              'key_val': ['a', 'a'] if x < 50 else ['b', 'b'],
-              'key_row_lengths': [x % 3, 2 - (x % 3)],
-          } for x in range(1, 100)],
-          input_metadata=tft.DatasetMetadata.from_feature_spec({
-              'x':
-                  tf.io.RaggedFeature(
-                      tf.int64,
-                      value_key='val',
-                      partitions=[
-                          tf.io.RaggedFeature.RowLengths('row_lengths')  # pytype: disable=attribute-error
-                      ]),
-              'key':
-                  tf.io.RaggedFeature(
-                      tf.string,
-                      value_key='key_val',
-                      partitions=[
-                          tf.io.RaggedFeature.RowLengths('key_row_lengths')  # pytype: disable=attribute-error
-                      ]),
-          }),
-          expected_data=[{
-              'x_bucketized$ragged_values': [
-                  _compute_simple_per_key_bucket(x, 'a' if x < 50 else 'b'),
-              ] * 2,
-              'x_bucketized$row_lengths_1': [x % 3, 2 - (x % 3)],
-          } for x in range(1, 100)],
-          expected_metadata=tft.DatasetMetadata.from_feature_spec(
-              {
-                  'x_bucketized':
-                      tf.io.RaggedFeature(
-                          tf.int64,
-                          value_key='x_bucketized$ragged_values',
-                          partitions=[
-                              tf.io.RaggedFeature.RowLengths(  # pytype: disable=attribute-error
-                                  'x_bucketized$row_lengths_1')
-                          ]),
-              },
-              {
-                  'x_bucketized':
-                      schema_pb2.IntDomain(min=0, max=2, is_categorical=True),
-              })),
-      dict(
-          testcase_name='ragged_weighted',
-          input_data=[{
-              'val': [x, x],
-              'row_lengths': [2 - (x % 3), x % 3],
-              'key_val': ['a', 'a'] if x < 50 else ['b', 'b'],
-              'key_row_lengths': [
-                  2 - (x % 3),
-                  x % 3,
-              ],
-              'weights_val':
-                  ([0, 0] if x in _WEIGHTED_PER_KEY_0_RANGE else [1, 1]),
-              'weights_row_lengths': [
-                  2 - (x % 3),
-                  x % 3,
-              ],
-          } for x in range(1, 100)],
-          input_metadata=tft.DatasetMetadata.from_feature_spec({
-              'x':
-                  tf.io.RaggedFeature(
-                      tf.int64,
-                      value_key='val',
-                      partitions=[
-                          tf.io.RaggedFeature.RowLengths('row_lengths')  # pytype: disable=attribute-error
-                      ]),
-              'key':
-                  tf.io.RaggedFeature(
-                      tf.string,
-                      value_key='key_val',
-                      partitions=[
-                          tf.io.RaggedFeature.RowLengths('key_row_lengths')  # pytype: disable=attribute-error
-                      ]),
-              'weights':
-                  tf.io.RaggedFeature(
-                      tf.int64,
-                      value_key='weights_val',
-                      partitions=[
-                          tf.io.RaggedFeature.RowLengths('weights_row_lengths')  # pytype: disable=attribute-error
-                      ]),
-          }),
-          expected_data=[{
-              'x_bucketized$ragged_values': [
-                  _compute_simple_per_key_bucket(
-                      x, 'a' if x < 50 else 'b', weighted=True),
-              ] * 2,
-              'x_bucketized$row_lengths_1': [2 - (x % 3), x % 3],
-          } for x in range(1, 100)],
-          expected_metadata=tft.DatasetMetadata.from_feature_spec(
-              {
-                  'x_bucketized':
-                      tf.io.RaggedFeature(
-                          tf.int64,
-                          value_key='x_bucketized$ragged_values',
-                          partitions=[
-                              tf.io.RaggedFeature.RowLengths(  # pytype: disable=attribute-error
-                                  'x_bucketized$row_lengths_1')
-                          ]),
-              },
-              {
-                  'x_bucketized':
-                      schema_pb2.IntDomain(min=0, max=2, is_categorical=True),
-              })),
-  ])
-

 class BucketizeIntegrationTest(tft_unit.TransformTestCase):

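For context on the new test cases: a tf.io.RaggedFeature with a RowLengths partition reassembles the flat 'val' batch and its 'row_lengths' into a tf.RaggedTensor, and with values 1..9 split into three equal-width buckets the expected bucket for a value v is (v - 1) // 3. A minimal sketch of both (illustrative only, not part of this commit; the variable names here are invented):

    import tensorflow as tf

    # One sample from the composite-input 'ragged' case, with x = 4.
    x = 4
    values = [x, 10 - x]                 # 'val'         -> [4, 6]
    row_lengths = [0, x % 3, 2 - x % 3]  # 'row_lengths' -> [0, 1, 1]

    # A RaggedFeature(value_key='val', partitions=[RowLengths('row_lengths')])
    # spec decodes to the same structure as:
    rt = tf.RaggedTensor.from_row_lengths(values, row_lengths)
    print(rt)  # <tf.RaggedTensor [[], [4], [6]]>

    # Expected bucket index per value v is (v - 1) // 3; for x = 4 this gives
    # [1, 1], matching 'x_bucketized$ragged_values' in the test case above.
    assert [(v - 1) // 3 for v in values] == [1, 1]

Note that the output keeps the same row partitioning as the input ('x_bucketized$row_lengths_1' equals the input 'row_lengths'); only the flat values are bucketized.
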
tensorflow_transform/beam/impl.py

Lines changed: 1 addition & 9 deletions
@@ -91,7 +91,6 @@
 # once the Spark issue is resolved.
 from tfx_bsl.types import tfx_namedtuple

-from tensorflow.python.framework import ops  # pylint: disable=g-direct-tensorflow-import
 from tensorflow_metadata.proto.v0 import schema_pb2

 # TODO(b/123325923): Fix the key type here to agree with the actual keys.
@@ -541,17 +540,10 @@ def _get_tensor_replacement_map(graph, *tensor_bindings):
   """Get Tensor replacement map."""
   tensor_replacement_map = {}

-  is_graph_mode = not ops.executing_eagerly_outside_functions()
   for tensor_binding in tensor_bindings:
     assert isinstance(tensor_binding, _TensorBinding), tensor_binding
-    value = tensor_binding.value
-    # TODO(b/160294509): tf.constant doesn't accept List[np.ndarray] in TF 1.15
-    # graph mode. Remove this condition.
-    if (is_graph_mode and isinstance(value, list) and
-        any(isinstance(x, np.ndarray) for x in value)):
-      value = np.asarray(tensor_binding.value)
     replacement_tensor = tf.constant(
-        value, tf.dtypes.as_dtype(tensor_binding.dtype_enum))
+        tensor_binding.value, tf.dtypes.as_dtype(tensor_binding.dtype_enum))
     if graph is not None and tensor_binding.is_asset_filepath:
       graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS,
                               replacement_tensor)