Skip to content

Commit 0c3e662

Browse files
iindyktfx-copybara
authored and committed
Adding telemetry for TensorRepresentations in input and output schema.
PiperOrigin-RevId: 441026555
1 parent 40adfeb commit 0c3e662

File tree

4 files changed

+68
-3
lines changed

4 files changed

+68
-3
lines changed

tensorflow_transform/beam/impl.py

+51
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
from tensorflow_transform.tf_metadata import metadata_io
7979
from tensorflow_transform.tf_metadata import schema_utils
8080
from tfx_bsl.telemetry import collection as telemetry
81+
from tfx_bsl.telemetry import util as telemetry_util
8182
from tfx_bsl.tfxio import tensor_representation_util
8283
from tfx_bsl.tfxio import tensor_to_arrow
8384
from tfx_bsl.tfxio import tf_example_record
@@ -1078,6 +1079,15 @@ def expand(self, dataset):
10781079
>> telemetry.TrackRecordBatchBytes(beam_common.METRICS_NAMESPACE,
10791080
'analysis_input_bytes'))
10801081

1082+
# Gather telemetry on types of input features.
1083+
_ = (
1084+
self.pipeline | 'CreateAnalyzeInputTensorRepresentations' >>
1085+
beam.Create([input_tensor_adapter_config.tensor_representations])
1086+
|
1087+
'InstrumentAnalyzeInputTensors' >> telemetry.TrackTensorRepresentations(
1088+
telemetry_util.AppendToNamespace(beam_common.METRICS_NAMESPACE,
1089+
['analyze_input_tensors'])))
1090+
10811091
asset_map = annotators.get_asset_annotations(graph)
10821092
# TF.HUB can error when unapproved collections are present. So we explicitly
10831093
# clear out the collections in the graph.
@@ -1351,6 +1361,20 @@ def _remove_columns_from_metadata(metadata, excluded_columns):
13511361
new_feature_spec, new_domains)
13521362

13531363

1364+
class _MaybeInferTensorRepresentationsDoFn(beam.DoFn):
1365+
"""Tries to infer TensorRepresentations from a Schema."""
1366+
1367+
def process(
1368+
self, schema: schema_pb2.Schema
1369+
) -> Iterable[Dict[str, schema_pb2.TensorRepresentation]]:
1370+
try:
1371+
yield (tensor_representation_util
1372+
.InferTensorRepresentationsFromMixedSchema(schema))
1373+
except ValueError:
1374+
# Ignore any inference errors since the output is only used for metrics.
1375+
yield {}
1376+
1377+
13541378
@beam.typehints.with_input_types(Union[_DatasetElementType, pa.RecordBatch],
13551379
Union[dataset_metadata.DatasetMetadata,
13561380
TensorAdapterConfig,
@@ -1446,11 +1470,20 @@ def expand(self, dataset_and_transform_fn):
14461470
self.pipeline
14471471
| 'CreateDeferredSchema' >> beam.Create([output_metadata.schema]))
14481472

1473+
# Increment input metrics.
14491474
_ = (
14501475
input_values
14511476
| 'InstrumentInputBytes[Transform]' >> telemetry.TrackRecordBatchBytes(
14521477
beam_common.METRICS_NAMESPACE, 'transform_input_bytes'))
14531478

1479+
_ = (
1480+
self.pipeline | 'CreateTransformInputTensorRepresentations' >>
1481+
beam.Create([input_tensor_adapter_config.tensor_representations])
1482+
| 'InstrumentTransformInputTensors' >>
1483+
telemetry.TrackTensorRepresentations(
1484+
telemetry_util.AppendToNamespace(beam_common.METRICS_NAMESPACE,
1485+
['transform_input_tensors'])))
1486+
14541487
tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(
14551488
type(self.pipeline.runner))
14561489
output_batches = (
@@ -1471,20 +1504,38 @@ def expand(self, dataset_and_transform_fn):
14711504
converter_pcol = (
14721505
deferred_schema | 'MakeTensorToArrowConverter' >> beam.Map(
14731506
impl_helper.make_tensor_to_arrow_converter))
1507+
1508+
output_tensor_representations = (
1509+
converter_pcol
1510+
| 'MapToTensorRepresentations' >>
1511+
beam.Map(lambda converter: converter.tensor_representations()))
1512+
14741513
output_data = (
14751514
output_batches | 'ConvertToRecordBatch' >> beam.Map(
14761515
_convert_to_record_batch,
14771516
schema=beam.pvalue.AsSingleton(deferred_schema),
14781517
converter=beam.pvalue.AsSingleton(converter_pcol),
14791518
passthrough_keys=Context.get_passthrough_keys(),
14801519
input_metadata=input_metadata))
1520+
14811521
else:
1522+
1523+
output_tensor_representations = (
1524+
deferred_schema | 'MaybeInferTensorRepresentations' >> beam.ParDo(
1525+
_MaybeInferTensorRepresentationsDoFn()))
14821526
output_data = (
14831527
output_batches | 'ConvertAndUnbatchToInstanceDicts' >> beam.FlatMap(
14841528
_convert_and_unbatch_to_instance_dicts,
14851529
schema=beam.pvalue.AsSingleton(deferred_schema),
14861530
passthrough_keys=Context.get_passthrough_keys()))
14871531

1532+
# Increment output data metrics.
1533+
_ = (
1534+
output_tensor_representations
1535+
| 'InstrumentTransformOutputTensors' >>
1536+
telemetry.TrackTensorRepresentations(
1537+
telemetry_util.AppendToNamespace(beam_common.METRICS_NAMESPACE,
1538+
['transform_output_tensors'])))
14881539
_clear_shared_state_after_barrier(self.pipeline, output_data)
14891540

14901541
return (output_data, output_metadata)

tensorflow_transform/beam/impl_test.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -3526,7 +3526,8 @@ def preprocessing_fn(inputs):
35263526
})
35273527
with tft_beam.Context(temp_dir=self.get_temp_dir()):
35283528
_ = ((input_data, metadata)
3529-
| 'AnalyzeDataset' >> tft_beam.AnalyzeDataset(preprocessing_fn))
3529+
| 'AnalyzeAndTransformDataset' >>
3530+
tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
35303531

35313532
metrics = pipeline.metrics
35323533
self.assertMetricsCounterEqual(metrics, 'tft_analyzer_vocabulary', 1)
@@ -3537,6 +3538,12 @@ def preprocessing_fn(inputs):
35373538
# We check that that call is not logged.
35383539
self.assertMetricsCounterEqual(metrics, 'tft_mapper_apply_vocabulary', 0)
35393540

3541+
for namespace in ('tfx.Transform.analyze_input_tensors',
3542+
'tfx.Transform.transform_input_tensors',
3543+
'tfx.Transform.transform_output_tensors'):
3544+
self.assertMetricsCounterEqual(
3545+
metrics, 'dense_tensor', 3, namespaces_list=[namespace])
3546+
35403547
def testNumBytesCounter(self):
35413548
self._SkipIfOutputRecordBatches()
35423549

tensorflow_transform/beam/tft_unit.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,14 @@ def assertMetricsCounterEqual(self, metrics, name, expected_count,
116116
metrics_filter)['counters']
117117
committed = sum([r.committed for r in metric])
118118
attempted = sum([r.attempted for r in metric])
119-
self.assertEqual(committed, attempted)
120-
self.assertEqual(committed, expected_count)
119+
self.assertEqual(
120+
committed,
121+
attempted,
122+
msg=f'Attempted counter {name} from namespace {namespaces_list}')
123+
self.assertEqual(
124+
committed,
125+
expected_count,
126+
msg=f'Expected counter {name} from namespace {namespaces_list}')
121127

122128
def assertAnalyzerOutputs(self,
123129
input_data,

tensorflow_transform/pickle_helper.py

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
_PROTO_CLASSES = [
3030
tf.compat.v1.ConfigProto,
3131
schema_pb2.Schema,
32+
schema_pb2.TensorRepresentation,
3233
statistics_pb2.DatasetFeatureStatistics,
3334
] + _ANNOTATION_CLASSES
3435

0 commit comments

Comments (0)