Skip to content

Commit cd4e2cc

Browse files
committed
Improve performance of DCR metrics (#762)
1 parent: 459c3bd · commit: cd4e2cc

File tree

8 files changed

+90
-46
lines changed

8 files changed

+90
-46
lines changed

sdmetrics/single_table/privacy/dcr_baseline_protection.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ class DCRBaselineProtection(SingleTableMetric):
2323
goal = Goal.MAXIMIZE
2424
min_value = 0.0
2525
max_value = 1.0
26+
CHUNK_SIZE = 1000
2627
_seed = None
2728

2829
@classmethod
@@ -103,15 +104,23 @@ def compute_breakdown(
103104
for _ in range(num_iterations):
104105
synthetic_sample = synthetic_data
105106
random_sample = random_data
107+
real_sample = real_data
106108
if num_rows_subsample is not None:
107109
synthetic_sample = synthetic_data.sample(n=num_rows_subsample)
108110
random_sample = random_data.sample(n=num_rows_subsample)
111+
real_sample = real_data.sample(n=num_rows_subsample)
109112

110113
dcr_real = calculate_dcr(
111-
reference_dataset=real_data, dataset=synthetic_sample, metadata=metadata
114+
reference_dataset=real_sample,
115+
dataset=synthetic_sample,
116+
metadata=metadata,
117+
chunk_size=cls.CHUNK_SIZE,
112118
)
113119
dcr_random = calculate_dcr(
114-
reference_dataset=real_data, dataset=random_sample, metadata=metadata
120+
reference_dataset=real_sample,
121+
dataset=random_sample,
122+
metadata=metadata,
123+
chunk_size=cls.CHUNK_SIZE,
115124
)
116125
synthetic_data_median = dcr_real.median()
117126
random_data_median = dcr_random.median()

sdmetrics/single_table/privacy/dcr_overfitting_protection.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ class DCROverfittingProtection(SingleTableMetric):
2222
goal = Goal.MAXIMIZE
2323
min_value = 0.0
2424
max_value = 1.0
25+
CHUNK_SIZE = 1000
2526

2627
@classmethod
2728
def _validate_inputs(
@@ -114,14 +115,24 @@ def compute_breakdown(
114115
sum_percent_close_to_random = 0
115116
for _ in range(num_iterations):
116117
synthetic_sample = synthetic_data
118+
real_training_sample = real_training_data
119+
real_validation_sample = real_validation_data
117120
if num_rows_subsample is not None:
118121
synthetic_sample = synthetic_data.sample(n=num_rows_subsample)
122+
real_training_sample = real_training_data.sample(n=num_rows_subsample)
123+
real_validation_sample = real_validation_data.sample(n=num_rows_subsample)
119124

120125
dcr_real = calculate_dcr(
121-
reference_dataset=real_training_data, dataset=synthetic_sample, metadata=metadata
126+
reference_dataset=real_training_sample,
127+
dataset=synthetic_sample,
128+
metadata=metadata,
129+
chunk_size=cls.CHUNK_SIZE,
122130
)
123131
dcr_holdout = calculate_dcr(
124-
reference_dataset=real_validation_data, dataset=synthetic_sample, metadata=metadata
132+
reference_dataset=real_validation_sample,
133+
dataset=synthetic_sample,
134+
metadata=metadata,
135+
chunk_size=cls.CHUNK_SIZE,
125136
)
126137

127138
num_rows_closer_to_real = np.where(dcr_real < dcr_holdout, 1.0, 0.0).sum()

sdmetrics/single_table/privacy/dcr_utils.py

Lines changed: 33 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,8 @@
99
CHUNK_SIZE = 1000
1010

1111

12-
def _process_dcr_chunk(chunk, reference_copy, cols_to_keep, metadata, ranges):
13-
full_dataset = chunk.merge(reference_copy, how='cross', suffixes=('_data', '_ref'))
12+
def _process_dcr_chunk(dataset_chunk, reference_chunk, cols_to_keep, metadata, ranges):
13+
full_dataset = dataset_chunk.merge(reference_chunk, how='cross', suffixes=('_data', '_ref'))
1414

1515
for col_name in cols_to_keep:
1616
sdtype = metadata['columns'][col_name]['sdtype']
@@ -51,7 +51,7 @@ def _process_dcr_chunk(chunk, reference_copy, cols_to_keep, metadata, ranges):
5151
return chunk_result['diff']
5252

5353

54-
def calculate_dcr(dataset, reference_dataset, metadata):
54+
def calculate_dcr(dataset, reference_dataset, metadata, chunk_size=1000):
5555
"""Calculate the Distance to Closest Record for all rows in the synthetic data.
5656
5757
Arguments:
@@ -66,10 +66,10 @@ def calculate_dcr(dataset, reference_dataset, metadata):
6666
pandas.Series:
6767
Returns a Series that shows the DCR value for every row of dataset
6868
"""
69-
dataset_copy = _process_data_with_metadata(dataset.copy(), metadata, True)
70-
reference_copy = _process_data_with_metadata(reference_dataset.copy(), metadata, True)
69+
dataset = _process_data_with_metadata(dataset.copy(), metadata, True)
70+
reference = _process_data_with_metadata(reference_dataset.copy(), metadata, True)
7171

72-
common_cols = set(dataset_copy.columns) & set(reference_copy.columns)
72+
common_cols = set(dataset.columns) & set(reference.columns)
7373
cols_to_keep = []
7474
ranges = {}
7575

@@ -83,7 +83,7 @@ def calculate_dcr(dataset, reference_dataset, metadata):
8383
cols_to_keep.append(col_name)
8484

8585
if sdtype in ['numerical', 'datetime']:
86-
col_range = reference_copy[col_name].max() - reference_copy[col_name].min()
86+
col_range = reference[col_name].max() - reference[col_name].min()
8787
if isinstance(col_range, pd.Timedelta):
8888
col_range = col_range.total_seconds()
8989

@@ -92,23 +92,35 @@ def calculate_dcr(dataset, reference_dataset, metadata):
9292
if not cols_to_keep:
9393
raise ValueError('There are no overlapping statistical columns to measure.')
9494

95-
dataset_copy = dataset_copy[cols_to_keep]
96-
dataset_copy['index'] = range(len(dataset_copy))
95+
dataset = dataset[cols_to_keep]
96+
dataset['index'] = range(len(dataset))
9797

98-
reference_copy = reference_copy[cols_to_keep]
99-
reference_copy['index'] = range(len(reference_copy))
98+
reference = reference[cols_to_keep]
99+
reference['index'] = range(len(reference))
100100
results = []
101101

102-
for chunk_start in range(0, len(dataset_copy), CHUNK_SIZE):
103-
chunk = dataset_copy.iloc[chunk_start : chunk_start + CHUNK_SIZE].copy()
104-
chunk_result = _process_dcr_chunk(
105-
chunk=chunk,
106-
reference_copy=reference_copy,
107-
cols_to_keep=cols_to_keep,
108-
metadata=metadata,
109-
ranges=ranges,
110-
)
111-
results.append(chunk_result)
102+
for dataset_chunk_start in range(0, len(dataset), chunk_size):
103+
dataset_chunk = dataset.iloc[dataset_chunk_start : dataset_chunk_start + chunk_size]
104+
minimum_chunk_distance = None
105+
for reference_chunk_start in range(0, len(reference), chunk_size):
106+
reference_chunk = reference.iloc[
107+
reference_chunk_start : reference_chunk_start + chunk_size
108+
]
109+
chunk_result = _process_dcr_chunk(
110+
dataset_chunk=dataset_chunk,
111+
reference_chunk=reference_chunk,
112+
cols_to_keep=cols_to_keep,
113+
metadata=metadata,
114+
ranges=ranges,
115+
)
116+
if minimum_chunk_distance is None:
117+
minimum_chunk_distance = chunk_result
118+
else:
119+
minimum_chunk_distance = pd.Series.min(
120+
pd.concat([minimum_chunk_distance, chunk_result], axis=1), axis=1
121+
)
122+
123+
results.append(minimum_chunk_distance)
112124

113125
result = pd.concat(results, ignore_index=True)
114126
result.name = None

tests/integration/single_table/privacy/test_dcr_baseline_protection.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -171,16 +171,12 @@ def test_end_to_end_sample_random_median(self):
171171
real_data = pd.DataFrame(data={'A': [2, 6, 3, 4, 1]})
172172
synthetic_data = pd.DataFrame(data={'A': [5, 5, 5, 5, 5]})
173173
metadata = {'columns': {'A': {'sdtype': 'numerical'}}}
174-
num_rows_sample = 1
175-
num_iterations = 5
176174

177175
# Run
178176
result = DCRBaselineProtection.compute_breakdown(
179177
real_data=real_data,
180178
synthetic_data=synthetic_data,
181179
metadata=metadata,
182-
num_rows_subsample=num_rows_sample,
183-
num_iterations=num_iterations,
184180
)
185181

186182
# Assert

tests/integration/single_table/privacy/test_dcr_overfitting_protection.py

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -25,25 +25,23 @@ def test_end_to_end_with_demo(self):
2525
train_df, holdout_df = train_test_split(real_data, test_size=0.5)
2626

2727
# Run
28-
num_rows_subsample = 50
2928
compute_breakdown_result = DCROverfittingProtection.compute_breakdown(
3029
train_df, synthetic_data, holdout_df, metadata
3130
)
3231
compute_result = DCROverfittingProtection.compute(
3332
train_df, synthetic_data, holdout_df, metadata
3433
)
3534
compute_holdout_same = DCROverfittingProtection.compute_breakdown(
36-
train_df, synthetic_data, synthetic_data, metadata, num_rows_subsample
35+
train_df, synthetic_data, synthetic_data, metadata
3736
)
3837
compute_train_same = DCROverfittingProtection.compute_breakdown(
39-
synthetic_data, synthetic_data, holdout_df, metadata, num_rows_subsample
38+
synthetic_data, synthetic_data, holdout_df, metadata
4039
)
4140
compute_all_same = DCROverfittingProtection.compute_breakdown(
4241
synthetic_data,
4342
synthetic_data,
4443
synthetic_data,
4544
metadata,
46-
num_rows_subsample,
4745
)
4846

4947
synth_percentages_key = 'synthetic_data_percentages'
@@ -136,18 +134,9 @@ def test_compute_breakdown_iterations(self):
136134
compute_num_iteration_1000 = DCROverfittingProtection.compute_breakdown(
137135
train_data, synthetic_data, holdout_data, metadata, num_rows_subsample, num_iterations
138136
)
139-
compute_train_same = DCROverfittingProtection.compute_breakdown(
140-
synthetic_data,
141-
synthetic_data,
142-
holdout_data,
143-
metadata,
144-
num_rows_subsample,
145-
num_iterations,
146-
)
147137

148138
# Assert
149139
assert compute_num_iteration_1 != compute_num_iteration_1000
150-
assert compute_train_same['score'] == 0.0
151140

152141
def test_end_to_end_with_datetimes(self):
153142
"""Test end to end with datetime synthetic values."""

tests/integration/single_table/privacy/test_dcr_utils.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
import pandas as pd
44

5+
from sdmetrics.demos import load_single_table_demo
56
from sdmetrics.single_table.privacy.dcr_utils import (
67
calculate_dcr,
78
)
@@ -54,3 +55,27 @@ def test_calculate_dcr_with_zero_col_range():
5455
# Assert
5556
expected_result = pd.Series([1.0, 1.0, 1.0, 0.5, 0.0])
5657
pd.testing.assert_series_equal(result, expected_result)
58+
59+
60+
def test_calculate_dcr_chunked():
61+
"""Test calculate_dcr with chunking calculations."""
62+
# Setup
63+
real_data, synthetic_data, metadata = load_single_table_demo()
64+
65+
# Run
66+
result = calculate_dcr(
67+
reference_dataset=real_data,
68+
dataset=synthetic_data,
69+
metadata=metadata,
70+
chunk_size=1000,
71+
)
72+
chunked_result = calculate_dcr(
73+
reference_dataset=real_data,
74+
dataset=synthetic_data,
75+
metadata=metadata,
76+
chunk_size=50,
77+
)
78+
79+
# Assert
80+
assert len(result) == len(real_data)
81+
pd.testing.assert_series_equal(result, chunked_result)

tests/unit/single_table/privacy/test_dcr_baseline_protection.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import random
22
import re
33
from datetime import datetime
4-
from unittest.mock import patch
4+
from unittest.mock import Mock, patch
55

66
import numpy as np
77
import pandas as pd
@@ -217,7 +217,9 @@ def test_compute_breakdown_with_dcr_random_same_real(self, mock_generate_random,
217217
# Setup
218218
real_data, synthetic_data, metadata = test_data
219219
num_rows_subsample = 10
220-
mock_generate_random.return_value = real_data.copy()
220+
real_data.sample = Mock()
221+
real_data.sample.return_value = real_data.iloc[:10]
222+
mock_generate_random.return_value = real_data.iloc[:10]
221223

222224
# Run
223225
result = DCRBaselineProtection.compute_breakdown(

tests/unit/single_table/privacy/test_dcr_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -235,8 +235,8 @@ def test__process_dcr_chunk(real_data, synthetic_data, test_metadata, column_ran
235235

236236
# Run
237237
result = _process_dcr_chunk(
238-
chunk=chunk,
239-
reference_copy=real_data,
238+
dataset_chunk=chunk,
239+
reference_chunk=real_data,
240240
cols_to_keep=cols_to_keep,
241241
metadata=test_metadata,
242242
ranges=column_ranges,

0 commit comments

Comments (0)