Skip to content

Commit 8d9423d

Browse files
tf-transform-teamzoyahav
authored andcommitted
Project import generated by Copybara.
PiperOrigin-RevId: 202529631
1 parent 6ca34b6 commit 8d9423d

File tree

8 files changed

+183
-94
lines changed

8 files changed

+183
-94
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ other *untested* combinations may also work.
5353

5454
|tensorflow-transform |tensorflow |apache-beam[gcp]|
5555
|--------------------------------------------------------------------------------|--------------|----------------|
56-
|[GitHub master](https://github.com/tensorflow/transform/blob/master/RELEASE.md) |nightly (1.x) |2.4.0 |
56+
|[GitHub master](https://github.com/tensorflow/transform/blob/master/RELEASE.md) |nightly (1.x) |2.5.0 |
57+
|[0.8.0](https://github.com/tensorflow/transform/blob/v0.8.0/RELEASE.md) |1.8 |2.5.0 |
5758
|[0.6.0](https://github.com/tensorflow/transform/blob/v0.6.0/RELEASE.md) |1.6 |2.4.0 |
5859
|[0.5.0](https://github.com/tensorflow/transform/blob/v0.5.0/RELEASE.md) |1.5 |2.3.0 |
5960
|[0.4.0](https://github.com/tensorflow/transform/blob/v0.4.0/RELEASE.md) |1.4 |2.2.0 |

RELEASE.md

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Current version (not yet released; still in development)
1+
# Release 0.8.0
22

33
## Major Features and Improvements
44
* Add TFTransformOutput utility class that wraps the output of tf.Transform for
@@ -25,12 +25,7 @@
2525
e.g. `tft.coders.ExampleProtoCoder`.
2626
* Setting dtypes for numpy arrays in `tft.coders.ExampleProtoCoder` and
2727
`tft.coders.CsvCoder`.
28-
* tft.mean now supports SparseTensor when reduce_instance_dimensions=True.
29-
In this case it returns a scalar mean computed over the non-missing values of
30-
the SparseTensor.
31-
* tft.mean now supports SparseTensor when reduce_instance_dimensions=False.
32-
In this case it returns a vector mean computed over the non-missing values of
33-
the SparseTensor.
28+
* `tft.mean`, `tft.max` and `tft.var` now support `tf.SparseTensor`.
3429
* Update examples to use "core" TensorFlow estimator API (`tf.estimator`).
3530
* Depends on `protobuf>=3.6.0<4`.
3631

examples/census_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def convert_label(label):
208208
| 'FixCommasTestData' >> beam.Map(
209209
lambda line: line.replace(', ', ','))
210210
| 'RemoveTrailingPeriodsTestData' >> beam.Map(lambda line: line[:-1])
211-
| 'DecodeTestData' >> beam.Map(converter.decode))
211+
| 'DecodeTestData' >> MapAndFilterErrors(converter.decode))
212212

213213
raw_test_dataset = (raw_test_data, RAW_DATA_METADATA)
214214

examples/sentiment.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ example, the data in this example uses a single feature for the full text of a
3030
movie review. This is split into sentences using the `tf.string_split`
3131
function. The `tf.string_split` function takes a rank 1 tensor and converts it
3232
to a rank 2 `SparseTensor` that contains the individual tokens. Then, using
33-
`tft.string_to_int`, this `SparseTensor` is converted to a
33+
`tft.compute_and_apply_vocabulary`, this `SparseTensor` is converted to a
3434
`SparseTensor` of `int64`s with the same shape.
3535

3636
During the training and evaluation phase, the `SparseTensor` that represents

examples/sentiment_example.md

Lines changed: 0 additions & 57 deletions
This file was deleted.

setup.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,25 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Package Setup script for the tf.Transform binary.
14+
"""Package Setup script for tf.Transform.
1515
"""
1616
from setuptools import find_packages
1717
from setuptools import setup
1818

1919
# Tensorflow transform version.
20-
__version__ = '0.8.0dev'
20+
__version__ = '0.8.0'
2121

2222

2323
def _make_required_install_packages():
2424
return [
2525
'absl-py>=0.1.6',
26-
'apache-beam[gcp]>=2.4,<3',
27-
'numpy>=1.10,<2',
26+
'apache-beam[gcp]>=2.5,<3',
27+
'numpy>=1.13.3,<2',
2828

2929
# TF now requires protobuf>=3.6.0.
3030
'protobuf>=3.6.0,<4',
3131

32-
'six>=1.9,<2',
32+
'six>=1.10,<2',
3333

3434
]
3535

tensorflow_transform/analyzers.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,32 @@ def max(x, reduce_instance_dims=True, name=None): # pylint: disable=redefined-b
382382
Returns:
383383
A `Tensor`. Has the same type as `x`.
384384
"""
385-
return _numeric_combine([x], np.max, reduce_instance_dims, name)[0]
385+
combine_fn = np.max
386+
if isinstance(x, tf.SparseTensor):
387+
if reduce_instance_dims:
388+
x = x.values
389+
else:
390+
sparse_ones = tf.SparseTensor(
391+
indices=x.indices,
392+
values=tf.ones_like(x.values),
393+
dense_shape=x.dense_shape)
394+
ones_values = tf.sparse_reduce_sum(sparse_ones, axis=0, keep_dims=True)
395+
# sparse_reduce_max returns 0 when all
396+
# elements are missing along axis 0.
397+
# Replace the 0 with nan when float
398+
# and dtype.min when int.
399+
batch_has_no_values = tf.equal(ones_values, tf.cast(0, x.dtype))
400+
x = tf.sparse_reduce_max(x, axis=0, keep_dims=True)
401+
if x.dtype == tf.float32:
402+
missing_value = np.nan
403+
combine_fn = np.nanmax
404+
elif x.dtype == tf.float64:
405+
missing_value = np.float64(np.nan)
406+
combine_fn = np.nanmax
407+
else:
408+
missing_value = x.dtype.min
409+
x = tf.where(batch_has_no_values, tf.fill(tf.shape(x), missing_value), x)
410+
return _numeric_combine([x], combine_fn, reduce_instance_dims, name)[0]
386411

387412

388413
def _min_and_max(x, reduce_instance_dims=True, name=None): # pylint: disable=redefined-builtin
@@ -500,8 +525,8 @@ def var(x, reduce_instance_dims=True, name=None, output_dtype=None):
500525
(x - mean(x))**2 / length(x).
501526
502527
Args:
503-
x: A `Tensor`. Its type must be floating point (float{16|32|64}), or
504-
integral ([u]int{8|16|32|64}).
528+
x: `Tensor` or `SparseTensor`. Its type must be floating point
529+
(float{16|32|64}), or integral ([u]int{8|16|32|64}).
505530
reduce_instance_dims: By default collapses the batch and instance dimensions
506531
to arrive at a single scalar output. If False, only collapses the batch
507532
dimension and outputs a vector of the same shape as the input.
@@ -517,23 +542,34 @@ def var(x, reduce_instance_dims=True, name=None, output_dtype=None):
517542
TypeError: If the type of `x` is not supported.
518543
"""
519544
with tf.name_scope(name, 'var'):
520-
# Note: Calling `mean`, `sum`, and `size` as defined in this module, not the
521-
# builtins.
522-
x_mean = mean(x, reduce_instance_dims, output_dtype=output_dtype)
523-
# x_mean will be float16, float32, or float64, depending on type of x.
524-
squared_deviations = tf.square(tf.cast(x, x_mean.dtype) - x_mean)
525-
return mean(
526-
squared_deviations, reduce_instance_dims, output_dtype=output_dtype)
545+
return _mean_and_var(x, reduce_instance_dims, name, output_dtype)[1]
527546

528547

529548
def _mean_and_var(x, reduce_instance_dims=True, name=None, output_dtype=None):
530549
"""More efficient combined `mean` and `var`. See `var`."""
550+
if output_dtype is None:
551+
output_dtype = _MEAN_OUTPUT_DTYPE_MAP.get(x.dtype)
552+
if output_dtype is None:
553+
raise TypeError('Tensor type %r is not supported' % x.dtype)
531554
with tf.name_scope(name, 'mean_and_var'):
532555
# Note: Calling `mean`, `sum`, and `size` as defined in this module, not the
533556
# builtins.
534557
x_mean = mean(x, reduce_instance_dims, output_dtype=output_dtype)
535-
# x_mean will be float16, float32, or float64, depending on type of x.
536-
squared_deviations = tf.square(tf.cast(x, x_mean.dtype) - x_mean)
558+
if isinstance(x, tf.SparseTensor):
559+
if reduce_instance_dims:
560+
squared_deviations = tf.square(tf.cast(x.values, x_mean.dtype) - x_mean)
561+
else:
562+
# Only supports sparsetensors with rank 2.
563+
x.get_shape().assert_has_rank(2)
564+
mean_values = tf.gather(x_mean, x.indices[:, 1])
565+
squared_deviation_values = tf.square(
566+
tf.cast(x.values, x_mean.dtype) - mean_values)
567+
squared_deviations = tf.SparseTensor(
568+
indices=x.indices,
569+
values=squared_deviation_values,
570+
dense_shape=x.dense_shape)
571+
else:
572+
squared_deviations = tf.square(tf.cast(x, x_mean.dtype) - x_mean)
537573
x_var = mean(
538574
squared_deviations, reduce_instance_dims, output_dtype=output_dtype)
539575
return x_mean, x_var

tensorflow_transform/beam/impl_test.py

Lines changed: 124 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,66 @@ def analyzer_fn(inputs):
12821282
self.assertAnalyzerOutputs(
12831283
input_data, input_metadata, analyzer_fn, expected_outputs)
12841284

1285+
def testMaxWithSparseTensorReduceTrue(self):
1286+
1287+
def analyzer_fn(inputs):
1288+
return {'max': tft.max(inputs['sparse'])}
1289+
1290+
input_data = [{
1291+
'sparse': ([0, 1], [0., 1.])
1292+
}, {
1293+
'sparse': ([1, 3], [2., 3.])
1294+
}]
1295+
input_metadata = dataset_metadata.DatasetMetadata({
1296+
'sparse':
1297+
sch.ColumnSchema(
1298+
tf.float32, [4],
1299+
sch.SparseColumnRepresentation(
1300+
'val', [sch.SparseIndexField('idx', False)]))
1301+
})
1302+
expected_outputs = {'max': np.array(3., np.float32)}
1303+
self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
1304+
expected_outputs)
1305+
1306+
@tft_unit.parameters(
1307+
(tf.int32,),
1308+
(tf.int64,),
1309+
(tf.float32,),
1310+
(tf.float64,),
1311+
)
1312+
def testMaxWithSparseTensorReduceFalse(self, input_dtype):
1313+
1314+
def analyzer_fn(inputs):
1315+
return {'max': tft.max(inputs['sparse'], False)}
1316+
1317+
input_data = [{
1318+
'sparse': ([0, 1], [-1., 1.])
1319+
}, {
1320+
'sparse': ([1, 3], [2., 3.])
1321+
}]
1322+
input_metadata = dataset_metadata.DatasetMetadata({
1323+
'sparse':
1324+
sch.ColumnSchema(
1325+
input_dtype, [4],
1326+
sch.SparseColumnRepresentation(
1327+
'val', [sch.SparseIndexField('idx', False)]))
1328+
})
1329+
if input_dtype == tf.float32 or input_dtype == tf.float64:
1330+
expected_outputs = {
1331+
'max':
1332+
np.array([-1., 2., float('nan'), 3.], input_dtype.as_numpy_dtype)
1333+
}
1334+
else:
1335+
expected_outputs = {
1336+
'max':
1337+
np.array(
1338+
[-1, 2, np.iinfo(input_dtype.as_numpy_dtype).min, 3],
1339+
input_dtype.as_numpy_dtype)
1340+
}
1341+
1342+
self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
1343+
expected_outputs)
1344+
12851345
def testNumericMeanWithSparseTensorReduceTrue(self):
12861346

12871347
def analyzer_fn(inputs):
@@ -1341,6 +1401,70 @@ def analyzer_fn(inputs):
13411401
self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
13421402
expected_outputs)
13431403

1404+
@tft_unit.parameters(
1405+
(tf.int32,),
1406+
(tf.int64,),
1407+
(tf.float32,),
1408+
(tf.float64,),
1409+
)
1410+
def testVarWithSparseTensorReduceInstanceDimsTrue(self, input_dtype):
1411+
1412+
def analyzer_fn(inputs):
1413+
return {'var': tft.var(inputs['sparse'])}
1414+
1415+
input_data = [{
1416+
'sparse': ([0, 1], [0., 1.])
1417+
}, {
1418+
'sparse': ([1, 3], [2., 3.])
1419+
}]
1420+
input_metadata = dataset_metadata.DatasetMetadata({
1421+
'sparse':
1422+
sch.ColumnSchema(
1423+
input_dtype, [4],
1424+
sch.SparseColumnRepresentation(
1425+
'val', [sch.SparseIndexField('idx', False)]))
1426+
})
1427+
if input_dtype == tf.float64:
1428+
expected_outputs = {'var': np.array(1.25, np.float64)}
1429+
else:
1430+
expected_outputs = {'var': np.array(1.25, np.float32)}
1431+
self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
1432+
expected_outputs)
1433+
1434+
@tft_unit.parameters(
1435+
(tf.int32,),
1436+
(tf.int64,),
1437+
(tf.float32,),
1438+
(tf.float64,),
1439+
)
1440+
def testVarWithSparseTensorReduceInstanceDimsFalse(self, input_dtype):
1441+
1442+
def analyzer_fn(inputs):
1443+
return {'var': tft.var(inputs['sparse'], reduce_instance_dims=False)}
1444+
1445+
input_data = [{
1446+
'sparse': ([0, 1], [0., 1.])
1447+
}, {
1448+
'sparse': ([1, 3], [2., 3.])
1449+
}]
1450+
input_metadata = dataset_metadata.DatasetMetadata({
1451+
'sparse':
1452+
sch.ColumnSchema(
1453+
input_dtype, [4],
1454+
sch.SparseColumnRepresentation(
1455+
'val', [sch.SparseIndexField('idx', False)]))
1456+
})
1457+
if input_dtype == tf.float64:
1458+
expected_outputs = {
1459+
'var': np.array([0., .25, float('nan'), 0.], np.float64)
1460+
}
1461+
else:
1462+
expected_outputs = {
1463+
'var': np.array([0., .25, float('nan'), 0.], np.float32)
1464+
}
1465+
self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
1466+
expected_outputs)
1467+
13441468
def testNumericAnalyzersWithSparseInputs(self):
13451469
def repeat(in_tensor, value):
13461470
batch_size = tf.shape(in_tensor)[0]
@@ -1358,11 +1482,6 @@ def min_fn(inputs):
13581482
return {'min': repeat(inputs['a'], tft.min(inputs['a']))}
13591483
_ = input_dataset | beam_impl.AnalyzeDataset(min_fn)
13601484

1361-
with self.assertRaises(TypeError):
1362-
def max_fn(inputs):
1363-
return {'max': repeat(inputs['a'], tft.max(inputs['a']))}
1364-
_ = input_dataset | beam_impl.AnalyzeDataset(max_fn)
1365-
13661485
with self.assertRaises(TypeError):
13671486
def sum_fn(inputs):
13681487
return {'sum': repeat(inputs['a'], tft.sum(inputs['a']))}
@@ -1373,11 +1492,6 @@ def size_fn(inputs):
13731492
return {'size': repeat(inputs['a'], tft.size(inputs['a']))}
13741493
_ = input_dataset | beam_impl.AnalyzeDataset(size_fn)
13751494

1376-
with self.assertRaises(TypeError):
1377-
def var_fn(inputs):
1378-
return {'var': repeat(inputs['a'], tft.var(inputs['a']))}
1379-
_ = input_dataset | beam_impl.AnalyzeDataset(var_fn)
1380-
13811495
def testStringToTFIDF(self):
13821496
def preprocessing_fn(inputs):
13831497
inputs_as_ints = tft.compute_and_apply_vocabulary(

0 commit comments

Comments
 (0)