Commit 076ff26

Replace DataFrame.drop_duplicates with dictionary_encode and np.unique
PiperOrigin-RevId: 296922505
1 parent 776881e commit 076ff26
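As a rough, self-contained illustration of the change named in the commit message (toy arrays only; none of these variables come from the TFDV code itself), the old pandas-based deduplication and its NumPy replacement look roughly like this:

import numpy as np
import pandas as pd

# Toy (example_index, value) pairs; the middle two columns are duplicates.
example_indices = np.array([0, 1, 1, 2])
values = np.array([3, 5, 5, 4])

# Old approach: materialize a DataFrame and drop duplicate rows.
df = pd.DataFrame({'example_indices': example_indices, 'values': values})
old_result = df.drop_duplicates().set_index('example_indices')['values']

# New approach: stack the pairs into a 2 x N array and deduplicate columns.
# Note that np.unique returns the surviving columns in sorted order.
pairs = np.vstack([example_indices, values])
unique_pairs = np.unique(pairs, axis=1)
new_result = pd.Series(unique_pairs[1, :], name='values',
                       index=pd.Index(unique_pairs[0, :],
                                      name='example_indices'))

# Both yield values [3, 5, 4] at example indices [0, 1, 2] for this input.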

3 files changed: +99 -10 lines changed


RELEASE.md

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 ## Bug Fixes and Other Changes
 
 * Fix facets visualization.
+* Optimize LiftStatsGenerator for string features.
 
 ## Breaking Changes
 
tensorflow_data_validation/statistics/generators/lift_stats_generator.py

Lines changed: 24 additions & 10 deletions
@@ -133,21 +133,35 @@ def _get_example_value_presence(
     return None
 
   arr_flat = arr.flatten()
+  is_binary_like = arrow_util.is_binary_like(arr_flat.type)
+  assert boundaries is None or not is_binary_like, (
+      'Boundaries can only be applied to numeric columns')
+  if is_binary_like:
+    # Use dictionary_encode so we can use np.unique on object arrays.
+    dict_array = arr_flat.dictionary_encode()
+    arr_flat = dict_array.indices
+    arr_flat_dict = np.asarray(dict_array.dictionary)
   example_indices_flat = example_indices[
       array_util.GetFlattenedArrayParentIndices(arr).to_numpy()]
   if boundaries is not None:
     element_indices, bins = bin_util.bin_array(arr_flat, boundaries)
-    df = pd.DataFrame({
-        'example_indices': example_indices_flat[element_indices],
-        'values': bins
-    })
+    pairs = np.vstack([example_indices_flat[element_indices], bins])
   else:
-    df = pd.DataFrame({
-        'example_indices': example_indices_flat,
-        'values': np.asarray(arr_flat)
-    })
-  df_unique = df.drop_duplicates()
-  return df_unique.set_index('example_indices')['values']
+    pairs = np.vstack([example_indices_flat, np.asarray(arr_flat)])
+  if not pairs.size:
+    return None
+  # Deduplicate values which show up more than once in the same example. This
+  # makes P(X=x|Y=y) in the standard lift definition behave as
+  # P(x \in Xs | y \in Ys) if examples contain more than one value of X and Y.
+  unique_pairs = np.unique(pairs, axis=1)
+  example_indices = unique_pairs[0, :]
+  values = unique_pairs[1, :]
+  if is_binary_like:
+    # Return binary-like values as a pd.Categorical wrapped in a Series. This
+    # makes subsequent operations like pd.merge cheaper.
+    values = pd.Categorical.from_codes(values, categories=arr_flat_dict)
+  return pd.Series(values, name='values',
+                   index=pd.Index(example_indices, name='example_indices'))
 
 
 def _to_partial_copresence_counts(
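A minimal standalone sketch of the new string-feature path follows; the inputs below are toy data mirroring the new unit test, not calls into this module, and should behave as described given standard pyarrow, NumPy, and pandas semantics.

import numpy as np
import pandas as pd
import pyarrow as pa

# Flattened string values and their parent example indices, mimicking
# examples [['a'], ['a', 'a'], ['a', 'b'], ['b']].
arr_flat = pa.array(['a', 'a', 'a', 'a', 'b', 'b'])
example_indices_flat = np.array([0, 1, 1, 2, 2, 3])

# dictionary_encode maps each string to an integer code plus a dictionary of
# distinct values, so deduplication runs on integers instead of objects.
dict_array = arr_flat.dictionary_encode()
codes = np.asarray(dict_array.indices)
dictionary = np.asarray(dict_array.dictionary)

# Stack (example_index, code) pairs and drop duplicate columns; this is the
# np.unique(..., axis=1) replacement for DataFrame.drop_duplicates.
pairs = np.vstack([example_indices_flat, codes])
unique_pairs = np.unique(pairs, axis=1)

# Decode the surviving codes back to strings as a pd.Categorical.
values = pd.Categorical.from_codes(unique_pairs[1, :], categories=dictionary)
presence = pd.Series(values, name='values',
                     index=pd.Index(unique_pairs[0, :],
                                    name='example_indices'))
# For this input, presence mirrors the expected series constructed in
# test_example_value_presence_string_value below.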

tensorflow_data_validation/statistics/generators/lift_stats_generator_test.py

Lines changed: 74 additions & 0 deletions
@@ -19,6 +19,7 @@
 from __future__ import print_function
 
 from absl.testing import absltest
+import numpy as np
 import pandas as pd
 import pyarrow as pa
 
@@ -46,6 +47,21 @@ def test_example_value_presence(self):
         lift_stats_generator._get_example_value_presence(
             t, types.FeaturePath(['x']), boundaries=None))
 
+  def test_example_value_presence_string_value(self):
+    t = pa.Table.from_arrays([
+        pa.array([['a'], ['a', 'a'], ['a', 'b'], ['b']]),
+    ], ['x'])
+    expected_cat = pd.Categorical.from_codes([0, 0, 0, 1, 1],
+                                             categories=['a', 'b'])
+    expected_series = pd.Series(expected_cat,
+                                name='values',
+                                index=pd.Index([0, 1, 2, 2, 3],
+                                               name='example_indices'))
+    pd.testing.assert_series_equal(
+        expected_series,
+        lift_stats_generator._get_example_value_presence(
+            t, types.FeaturePath(['x']), boundaries=None))
+
   def test_example_value_presence_none_value(self):
     t = pa.Table.from_arrays([
         pa.array([[1], None]),
@@ -709,6 +725,64 @@ def test_lift_null_y(self):
         add_default_slice_key_to_input=True,
         add_default_slice_key_to_output=True)
 
+  def test_lift_missing_x_and_y(self):
+    examples = [
+        pa.Table.from_arrays([
+            # explicitly construct type to avoid treating as null type
+            pa.array([], type=pa.list_(pa.binary())),
+            pa.array([], type=pa.list_(pa.binary())),
+        ], ['categorical_x', 'string_y']),
+    ]
+    schema = text_format.Parse(
+        """
+        feature {
+          name: 'categorical_x'
+          type: BYTES
+        }
+        feature {
+          name: 'string_y'
+          type: BYTES
+        }
+        """, schema_pb2.Schema())
+    expected_result = []
+    generator = lift_stats_generator.LiftStatsGenerator(
+        schema=schema, y_path=types.FeaturePath(['string_y']))
+    self.assertSlicingAwareTransformOutputEqual(
+        examples,
+        generator,
+        expected_result,
+        add_default_slice_key_to_input=True,
+        add_default_slice_key_to_output=True)
+
+  def test_lift_float_y_is_nan(self):
+    # after calling bin_array, this is effectively an empty array.
+    examples = [
+        pa.Table.from_arrays([
+            pa.array([['a']]),
+            pa.array([[np.nan]]),
+        ], ['categorical_x', 'float_y']),
+    ]
+    schema = text_format.Parse(
+        """
+        feature {
+          name: 'categorical_x'
+          type: BYTES
+        }
+        feature {
+          name: 'float_y'
+          type: FLOAT
+        }
+        """, schema_pb2.Schema())
+    expected_result = []
+    generator = lift_stats_generator.LiftStatsGenerator(
+        schema=schema, y_path=types.FeaturePath(['float_y']), y_boundaries=[1])
+    self.assertSlicingAwareTransformOutputEqual(
+        examples,
+        generator,
+        expected_result,
+        add_default_slice_key_to_input=True,
+        add_default_slice_key_to_output=True)
+
   def test_lift_min_x_count(self):
     examples = [
         pa.Table.from_arrays([
