@@ -78,6 +78,7 @@
 from tensorflow_transform.tf_metadata import metadata_io
 from tensorflow_transform.tf_metadata import schema_utils
 from tfx_bsl.telemetry import collection as telemetry
+from tfx_bsl.telemetry import util as telemetry_util
 from tfx_bsl.tfxio import tensor_representation_util
 from tfx_bsl.tfxio import tensor_to_arrow
 from tfx_bsl.tfxio import tf_example_record
@@ -1078,6 +1079,15 @@ def expand(self, dataset):
         >> telemetry.TrackRecordBatchBytes(beam_common.METRICS_NAMESPACE,
                                            'analysis_input_bytes'))

+    # Gather telemetry on types of input features.
+    _ = (
+        self.pipeline | 'CreateAnalyzeInputTensorRepresentations' >>
+        beam.Create([input_tensor_adapter_config.tensor_representations])
+        |
+        'InstrumentAnalyzeInputTensors' >> telemetry.TrackTensorRepresentations(
+            telemetry_util.AppendToNamespace(beam_common.METRICS_NAMESPACE,
+                                             ['analyze_input_tensors'])))
+
     asset_map = annotators.get_asset_annotations(graph)
     # TF.HUB can error when unapproved collections are present. So we explicitly
     # clear out the collections in the graph.
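This hunk feeds a single-element PCollection, carrying the entire `tensor_representations` dict from the `TensorAdapterConfig`, into `telemetry.TrackTensorRepresentations`, which records feature-type counters under a sub-namespace of `beam_common.METRICS_NAMESPACE`; the result is discarded (`_ = ...`) because the branch exists only for its metric side effects. A minimal standalone sketch of the same wiring, assuming tfx_bsl is installed; the feature name and namespace are made-up placeholders:

```python
import apache_beam as beam
from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.telemetry import collection as telemetry
from tfx_bsl.telemetry import util as telemetry_util

# One TensorRepresentation per tensor, as a TensorAdapterConfig would carry.
representations = {
    'my_feature':
        text_format.Parse(
            'dense_tensor { column_name: "my_feature" '
            'shape { dim { size: 1 } } }',
            schema_pb2.TensorRepresentation()),
}

with beam.Pipeline() as p:
  # Single-element PCollection holding the whole dict, mirroring
  # 'CreateAnalyzeInputTensorRepresentations' above.
  _ = (
      p
      | beam.Create([representations])
      | telemetry.TrackTensorRepresentations(
          telemetry_util.AppendToNamespace('my_namespace',
                                           ['analyze_input_tensors'])))
```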
@@ -1351,6 +1361,20 @@ def _remove_columns_from_metadata(metadata, excluded_columns):
       new_feature_spec, new_domains)


+class _MaybeInferTensorRepresentationsDoFn(beam.DoFn):
+  """Tries to infer TensorRepresentations from a Schema."""
+
+  def process(
+      self, schema: schema_pb2.Schema
+  ) -> Iterable[Dict[str, schema_pb2.TensorRepresentation]]:
+    try:
+      yield (tensor_representation_util
+             .InferTensorRepresentationsFromMixedSchema(schema))
+    except ValueError:
+      # Ignore any inference errors since the output is only used for metrics.
+      yield {}
+
+
 @beam.typehints.with_input_types(Union[_DatasetElementType, pa.RecordBatch],
                                  Union[dataset_metadata.DatasetMetadata,
                                        TensorAdapterConfig,
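`_MaybeInferTensorRepresentationsDoFn` deliberately swallows `ValueError`: inference is best-effort and its output feeds metrics only, so a schema whose representations cannot be inferred degrades to an empty dict instead of failing the job. A hedged usage sketch, with a made-up one-feature schema:

```python
import apache_beam as beam
from tensorflow_metadata.proto.v0 import schema_pb2

schema = schema_pb2.Schema()
schema.feature.add(name='x', type=schema_pb2.INT)

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([schema])
      | beam.ParDo(_MaybeInferTensorRepresentationsDoFn())
      # Prints the inferred {name: TensorRepresentation} dict, or {} if
      # inference raised ValueError.
      | beam.Map(print))
```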
@@ -1446,11 +1470,20 @@ def expand(self, dataset_and_transform_fn):
         self.pipeline
         | 'CreateDeferredSchema' >> beam.Create([output_metadata.schema]))

+    # Increment input metrics.
     _ = (
         input_values
         | 'InstrumentInputBytes[Transform]' >> telemetry.TrackRecordBatchBytes(
             beam_common.METRICS_NAMESPACE, 'transform_input_bytes'))

+    _ = (
+        self.pipeline | 'CreateTransformInputTensorRepresentations' >>
+        beam.Create([input_tensor_adapter_config.tensor_representations])
+        | 'InstrumentTransformInputTensors' >>
+        telemetry.TrackTensorRepresentations(
+            telemetry_util.AppendToNamespace(beam_common.METRICS_NAMESPACE,
+                                             ['transform_input_tensors'])))
+
     tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(
         type(self.pipeline.runner))
     output_batches = (
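As in the analysis phase, the transform-phase input telemetry hangs off `self.pipeline` as a discarded side branch. On runners that support metric queries, the counters it emits can be read back from the pipeline result; a sketch follows, in which the namespace literal is only an assumption about what `telemetry_util.AppendToNamespace(beam_common.METRICS_NAMESPACE, ['transform_input_tensors'])` returns:

```python
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.runners.runner import PipelineResult


def print_tensor_counters(result: PipelineResult, namespace: str) -> None:
  """Prints every counter recorded under `namespace`."""
  for counter in result.metrics().query(
      MetricsFilter().with_namespace(namespace))['counters']:
    print(counter.key.metric.name, counter.committed)


# Assumption: AppendToNamespace joins segments with '.' and
# beam_common.METRICS_NAMESPACE is 'tfx.Transform'.
# print_tensor_counters(result, 'tfx.Transform.transform_input_tensors')
```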
@@ -1471,20 +1504,38 @@ def expand(self, dataset_and_transform_fn):
       converter_pcol = (
           deferred_schema | 'MakeTensorToArrowConverter' >> beam.Map(
               impl_helper.make_tensor_to_arrow_converter))
+
+      output_tensor_representations = (
+          converter_pcol
+          | 'MapToTensorRepresentations' >>
+          beam.Map(lambda converter: converter.tensor_representations()))
+
       output_data = (
           output_batches | 'ConvertToRecordBatch' >> beam.Map(
               _convert_to_record_batch,
               schema=beam.pvalue.AsSingleton(deferred_schema),
               converter=beam.pvalue.AsSingleton(converter_pcol),
               passthrough_keys=Context.get_passthrough_keys(),
               input_metadata=input_metadata))
+
     else:
+
+      output_tensor_representations = (
+          deferred_schema | 'MaybeInferTensorRepresentations' >> beam.ParDo(
+              _MaybeInferTensorRepresentationsDoFn()))
       output_data = (
           output_batches | 'ConvertAndUnbatchToInstanceDicts' >> beam.FlatMap(
               _convert_and_unbatch_to_instance_dicts,
               schema=beam.pvalue.AsSingleton(deferred_schema),
               passthrough_keys=Context.get_passthrough_keys()))

+    # Increment output data metrics.
+    _ = (
+        output_tensor_representations
+        | 'InstrumentTransformOutputTensors' >>
+        telemetry.TrackTensorRepresentations(
+            telemetry_util.AppendToNamespace(beam_common.METRICS_NAMESPACE,
+                                             ['transform_output_tensors'])))
     _clear_shared_state_after_barrier(self.pipeline, output_data)

     return (output_data, output_metadata)
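The two `output_tensor_representations` sources mirror the two output modes: with RecordBatch output the representations come straight from the tensor-to-arrow converter, while the instance-dict mode falls back to best-effort inference from the output schema via the DoFn above. Both sources can be exercised outside Beam; in this sketch, tfx_bsl's public `TensorsToRecordBatchConverter` stands in for the converter built by `impl_helper.make_tensor_to_arrow_converter`, and the `'x'` specs are made up:

```python
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.tfxio import tensor_representation_util
from tfx_bsl.tfxio import tensor_to_arrow

# RecordBatch output path: representations come from the converter itself.
converter = tensor_to_arrow.TensorsToRecordBatchConverter(
    {'x': tf.TensorSpec(shape=[None, 1], dtype=tf.int64)})
print(converter.tensor_representations())

# Instance-dict output path: best-effort inference from the output schema,
# the same call _MaybeInferTensorRepresentationsDoFn wraps.
schema = schema_pb2.Schema()
schema.feature.add(name='x', type=schema_pb2.INT)
print(tensor_representation_util.InferTensorRepresentationsFromMixedSchema(
    schema))
```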