|
21 | 21 | import tensorflow as tf
|
22 | 22 | import tensorflow_transform as tft
|
23 | 23 | from tensorflow_transform import analyzers
|
24 |
| -from tensorflow_transform import common_types |
25 | 24 | from tensorflow_transform.beam import impl as beam_impl
|
26 | 25 | from tensorflow_transform.beam import tft_unit
|
27 | 26 | from tensorflow_metadata.proto.v0 import schema_pb2
|
@@ -129,7 +128,26 @@ def _compute_simple_per_key_bucket(val, key, weighted=False):
|
129 | 128 | 'x_bucketized$sparse_values': [(x - 1) // 3],
|
130 | 129 | 'x_bucketized$sparse_indices_0': [x % 4],
|
131 | 130 | 'x_bucketized$sparse_indices_1': [x % 5]
|
132 |
| - } for x in range(1, 10)]) |
| 131 | + } for x in range(1, 10)]), |
| 132 | + dict( |
| 133 | + testcase_name='ragged', |
| 134 | + input_data=[{ |
| 135 | + 'val': [x, 10 - x], |
| 136 | + 'row_lengths': [0, x % 3, 2 - x % 3], |
| 137 | + } for x in range(1, 10)], |
| 138 | + input_metadata=tft.DatasetMetadata.from_feature_spec({ |
| 139 | + 'x': |
| 140 | + tf.io.RaggedFeature( |
| 141 | + tf.int64, |
| 142 | + value_key='val', |
| 143 | + partitions=[ |
| 144 | + tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error |
| 145 | + ]), |
| 146 | + }), |
| 147 | + expected_data=[{ |
| 148 | + 'x_bucketized$ragged_values': [(x - 1) // 3, (9 - x) // 3], |
| 149 | + 'x_bucketized$row_lengths_1': [0, x % 3, 2 - x % 3], |
| 150 | + } for x in range(1, 10)]), |
133 | 151 | ]
|
134 | 152 |
|
135 | 153 | _BUCKETIZE_PER_KEY_TEST_CASES = [
|
@@ -211,139 +229,115 @@ def _compute_simple_per_key_bucket(val, key, weighted=False):
|
211 | 229 | 'x_bucketized':
|
212 | 230 | schema_pb2.IntDomain(min=0, max=2, is_categorical=True),
|
213 | 231 | })),
|
| 232 | + dict( |
| 233 | + testcase_name='ragged', |
| 234 | + input_data=[{ |
| 235 | + 'val': [x, x], |
| 236 | + 'row_lengths': [x % 3, 2 - (x % 3)], |
| 237 | + 'key_val': ['a', 'a'] if x < 50 else ['b', 'b'], |
| 238 | + 'key_row_lengths': [x % 3, 2 - (x % 3)], |
| 239 | + } for x in range(1, 100)], |
| 240 | + input_metadata=tft.DatasetMetadata.from_feature_spec({ |
| 241 | + 'x': |
| 242 | + tf.io.RaggedFeature( |
| 243 | + tf.int64, |
| 244 | + value_key='val', |
| 245 | + partitions=[ |
| 246 | + tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error |
| 247 | + ]), |
| 248 | + 'key': |
| 249 | + tf.io.RaggedFeature( |
| 250 | + tf.string, |
| 251 | + value_key='key_val', |
| 252 | + partitions=[ |
| 253 | + tf.io.RaggedFeature.RowLengths('key_row_lengths') # pytype: disable=attribute-error |
| 254 | + ]), |
| 255 | + }), |
| 256 | + expected_data=[{ |
| 257 | + 'x_bucketized$ragged_values': [ |
| 258 | + _compute_simple_per_key_bucket(x, 'a' if x < 50 else 'b'), |
| 259 | + ] * 2, |
| 260 | + 'x_bucketized$row_lengths_1': [x % 3, 2 - (x % 3)], |
| 261 | + } for x in range(1, 100)], |
| 262 | + expected_metadata=tft.DatasetMetadata.from_feature_spec( |
| 263 | + { |
| 264 | + 'x_bucketized': |
| 265 | + tf.io.RaggedFeature( |
| 266 | + tf.int64, |
| 267 | + value_key='x_bucketized$ragged_values', |
| 268 | + partitions=[ |
| 269 | + tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error |
| 270 | + 'x_bucketized$row_lengths_1') |
| 271 | + ]), |
| 272 | + }, |
| 273 | + { |
| 274 | + 'x_bucketized': |
| 275 | + schema_pb2.IntDomain(min=0, max=2, is_categorical=True), |
| 276 | + })), |
| 277 | + dict( |
| 278 | + testcase_name='ragged_weighted', |
| 279 | + input_data=[{ |
| 280 | + 'val': [x, x], |
| 281 | + 'row_lengths': [2 - (x % 3), x % 3], |
| 282 | + 'key_val': ['a', 'a'] if x < 50 else ['b', 'b'], |
| 283 | + 'key_row_lengths': [ |
| 284 | + 2 - (x % 3), |
| 285 | + x % 3, |
| 286 | + ], |
| 287 | + 'weights_val': |
| 288 | + ([0, 0] if x in _WEIGHTED_PER_KEY_0_RANGE else [1, 1]), |
| 289 | + 'weights_row_lengths': [ |
| 290 | + 2 - (x % 3), |
| 291 | + x % 3, |
| 292 | + ], |
| 293 | + } for x in range(1, 100)], |
| 294 | + input_metadata=tft.DatasetMetadata.from_feature_spec({ |
| 295 | + 'x': |
| 296 | + tf.io.RaggedFeature( |
| 297 | + tf.int64, |
| 298 | + value_key='val', |
| 299 | + partitions=[ |
| 300 | + tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error |
| 301 | + ]), |
| 302 | + 'key': |
| 303 | + tf.io.RaggedFeature( |
| 304 | + tf.string, |
| 305 | + value_key='key_val', |
| 306 | + partitions=[ |
| 307 | + tf.io.RaggedFeature.RowLengths('key_row_lengths') # pytype: disable=attribute-error |
| 308 | + ]), |
| 309 | + 'weights': |
| 310 | + tf.io.RaggedFeature( |
| 311 | + tf.int64, |
| 312 | + value_key='weights_val', |
| 313 | + partitions=[ |
| 314 | + tf.io.RaggedFeature.RowLengths('weights_row_lengths') # pytype: disable=attribute-error |
| 315 | + ]), |
| 316 | + }), |
| 317 | + expected_data=[{ |
| 318 | + 'x_bucketized$ragged_values': [ |
| 319 | + _compute_simple_per_key_bucket( |
| 320 | + x, 'a' if x < 50 else 'b', weighted=True), |
| 321 | + ] * 2, |
| 322 | + 'x_bucketized$row_lengths_1': [2 - (x % 3), x % 3], |
| 323 | + } for x in range(1, 100)], |
| 324 | + expected_metadata=tft.DatasetMetadata.from_feature_spec( |
| 325 | + { |
| 326 | + 'x_bucketized': |
| 327 | + tf.io.RaggedFeature( |
| 328 | + tf.int64, |
| 329 | + value_key='x_bucketized$ragged_values', |
| 330 | + partitions=[ |
| 331 | + tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error |
| 332 | + 'x_bucketized$row_lengths_1') |
| 333 | + ]), |
| 334 | + }, |
| 335 | + { |
| 336 | + 'x_bucketized': |
| 337 | + schema_pb2.IntDomain(min=0, max=2, is_categorical=True), |
| 338 | + })), |
214 | 339 | ]
|
215 | 340 |
|
216 |
| -if common_types.is_ragged_feature_available(): |
217 |
| - _BUCKETIZE_COMPOSITE_INPUT_TEST_CASES.append( |
218 |
| - dict( |
219 |
| - testcase_name='ragged', |
220 |
| - input_data=[{ |
221 |
| - 'val': [x, 10 - x], |
222 |
| - 'row_lengths': [0, x % 3, 2 - x % 3], |
223 |
| - } for x in range(1, 10)], |
224 |
| - input_metadata=tft.DatasetMetadata.from_feature_spec({ |
225 |
| - 'x': |
226 |
| - tf.io.RaggedFeature( |
227 |
| - tf.int64, |
228 |
| - value_key='val', |
229 |
| - partitions=[ |
230 |
| - tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error |
231 |
| - ]), |
232 |
| - }), |
233 |
| - expected_data=[{ |
234 |
| - 'x_bucketized$ragged_values': [(x - 1) // 3, (9 - x) // 3], |
235 |
| - 'x_bucketized$row_lengths_1': [0, x % 3, 2 - x % 3], |
236 |
| - } for x in range(1, 10)])) |
237 |
| - _BUCKETIZE_PER_KEY_TEST_CASES.extend([ |
238 |
| - dict( |
239 |
| - testcase_name='ragged', |
240 |
| - input_data=[{ |
241 |
| - 'val': [x, x], |
242 |
| - 'row_lengths': [x % 3, 2 - (x % 3)], |
243 |
| - 'key_val': ['a', 'a'] if x < 50 else ['b', 'b'], |
244 |
| - 'key_row_lengths': [x % 3, 2 - (x % 3)], |
245 |
| - } for x in range(1, 100)], |
246 |
| - input_metadata=tft.DatasetMetadata.from_feature_spec({ |
247 |
| - 'x': |
248 |
| - tf.io.RaggedFeature( |
249 |
| - tf.int64, |
250 |
| - value_key='val', |
251 |
| - partitions=[ |
252 |
| - tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error |
253 |
| - ]), |
254 |
| - 'key': |
255 |
| - tf.io.RaggedFeature( |
256 |
| - tf.string, |
257 |
| - value_key='key_val', |
258 |
| - partitions=[ |
259 |
| - tf.io.RaggedFeature.RowLengths('key_row_lengths') # pytype: disable=attribute-error |
260 |
| - ]), |
261 |
| - }), |
262 |
| - expected_data=[{ |
263 |
| - 'x_bucketized$ragged_values': [ |
264 |
| - _compute_simple_per_key_bucket(x, 'a' if x < 50 else 'b'), |
265 |
| - ] * 2, |
266 |
| - 'x_bucketized$row_lengths_1': [x % 3, 2 - (x % 3)], |
267 |
| - } for x in range(1, 100)], |
268 |
| - expected_metadata=tft.DatasetMetadata.from_feature_spec( |
269 |
| - { |
270 |
| - 'x_bucketized': |
271 |
| - tf.io.RaggedFeature( |
272 |
| - tf.int64, |
273 |
| - value_key='x_bucketized$ragged_values', |
274 |
| - partitions=[ |
275 |
| - tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error |
276 |
| - 'x_bucketized$row_lengths_1') |
277 |
| - ]), |
278 |
| - }, |
279 |
| - { |
280 |
| - 'x_bucketized': |
281 |
| - schema_pb2.IntDomain(min=0, max=2, is_categorical=True), |
282 |
| - })), |
283 |
| - dict( |
284 |
| - testcase_name='ragged_weighted', |
285 |
| - input_data=[{ |
286 |
| - 'val': [x, x], |
287 |
| - 'row_lengths': [2 - (x % 3), x % 3], |
288 |
| - 'key_val': ['a', 'a'] if x < 50 else ['b', 'b'], |
289 |
| - 'key_row_lengths': [ |
290 |
| - 2 - (x % 3), |
291 |
| - x % 3, |
292 |
| - ], |
293 |
| - 'weights_val': |
294 |
| - ([0, 0] if x in _WEIGHTED_PER_KEY_0_RANGE else [1, 1]), |
295 |
| - 'weights_row_lengths': [ |
296 |
| - 2 - (x % 3), |
297 |
| - x % 3, |
298 |
| - ], |
299 |
| - } for x in range(1, 100)], |
300 |
| - input_metadata=tft.DatasetMetadata.from_feature_spec({ |
301 |
| - 'x': |
302 |
| - tf.io.RaggedFeature( |
303 |
| - tf.int64, |
304 |
| - value_key='val', |
305 |
| - partitions=[ |
306 |
| - tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error |
307 |
| - ]), |
308 |
| - 'key': |
309 |
| - tf.io.RaggedFeature( |
310 |
| - tf.string, |
311 |
| - value_key='key_val', |
312 |
| - partitions=[ |
313 |
| - tf.io.RaggedFeature.RowLengths('key_row_lengths') # pytype: disable=attribute-error |
314 |
| - ]), |
315 |
| - 'weights': |
316 |
| - tf.io.RaggedFeature( |
317 |
| - tf.int64, |
318 |
| - value_key='weights_val', |
319 |
| - partitions=[ |
320 |
| - tf.io.RaggedFeature.RowLengths('weights_row_lengths') # pytype: disable=attribute-error |
321 |
| - ]), |
322 |
| - }), |
323 |
| - expected_data=[{ |
324 |
| - 'x_bucketized$ragged_values': [ |
325 |
| - _compute_simple_per_key_bucket( |
326 |
| - x, 'a' if x < 50 else 'b', weighted=True), |
327 |
| - ] * 2, |
328 |
| - 'x_bucketized$row_lengths_1': [2 - (x % 3), x % 3], |
329 |
| - } for x in range(1, 100)], |
330 |
| - expected_metadata=tft.DatasetMetadata.from_feature_spec( |
331 |
| - { |
332 |
| - 'x_bucketized': |
333 |
| - tf.io.RaggedFeature( |
334 |
| - tf.int64, |
335 |
| - value_key='x_bucketized$ragged_values', |
336 |
| - partitions=[ |
337 |
| - tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error |
338 |
| - 'x_bucketized$row_lengths_1') |
339 |
| - ]), |
340 |
| - }, |
341 |
| - { |
342 |
| - 'x_bucketized': |
343 |
| - schema_pb2.IntDomain(min=0, max=2, is_categorical=True), |
344 |
| - })), |
345 |
| - ]) |
346 |
| - |
347 | 341 |
|
348 | 342 | class BucketizeIntegrationTest(tft_unit.TransformTestCase):
|
349 | 343 |
|
|
0 commit comments