Skip to content

Commit f88c759

Browse files
authored
Optimize DCR calculation using vector operations (#754)
1 parent e45e334 commit f88c759

File tree

5 files changed

+124
-291
lines changed

5 files changed

+124
-291
lines changed

sdmetrics/single_table/privacy/dcr_baseline_protection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ def compute_breakdown(
115115
random_sample = random_data.sample(n=num_rows_subsample)
116116

117117
dcr_real = calculate_dcr(
118-
real_data=sanitized_real_data, synthetic_data=synthetic_sample, metadata=metadata
118+
reference_dataset=sanitized_real_data, dataset=synthetic_sample, metadata=metadata
119119
)
120120
dcr_random = calculate_dcr(
121-
real_data=sanitized_real_data, synthetic_data=random_sample, metadata=metadata
121+
reference_dataset=sanitized_real_data, dataset=random_sample, metadata=metadata
122122
)
123123
synthetic_data_median = dcr_real.median()
124124
random_data_median = dcr_random.median()

sdmetrics/single_table/privacy/dcr_overfitting_protection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ def compute_breakdown(
128128
synthetic_sample = sanitized_synthetic_data.sample(n=num_rows_subsample)
129129

130130
dcr_real = calculate_dcr(
131-
real_data=training_data, synthetic_data=synthetic_sample, metadata=metadata
131+
reference_dataset=training_data, dataset=synthetic_sample, metadata=metadata
132132
)
133133
dcr_holdout = calculate_dcr(
134-
real_data=validation_data, synthetic_data=synthetic_sample, metadata=metadata
134+
reference_dataset=validation_data, dataset=synthetic_sample, metadata=metadata
135135
)
136136

137137
num_rows_closer_to_real = np.where(dcr_real < dcr_holdout, 1.0, 0.0).sum()
Lines changed: 79 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,156 +1,116 @@
11
"""Distance to closest record measurement functions."""
22

3+
import numpy as np
34
import pandas as pd
45

56
from sdmetrics._utils_metadata import _process_data_with_metadata
67
from sdmetrics.utils import get_columns_from_metadata
78

9+
CHUNK_SIZE = 1000
810

9-
def _calculate_dcr_value(synthetic_value, real_value, sdtype, col_range=None):
10-
"""Calculate the Distance to Closest Record between two different values.
1111

12-
Arguments:
13-
synthetic_value (int, float, datetime, boolean, string, or None):
14-
The synthetic value that we are calculating DCR value for
15-
real_value (int, float, datetime, boolean, string, or None):
16-
The data value that we are referencing for measuring DCR.
17-
sdtype (string):
18-
The sdtype of the column values.
19-
col_range (float):
20-
The range of values for a column used for numerical values to calculate DCR.
21-
Defaults to None.
12+
def _process_dcr_chunk(chunk, reference_copy, cols_to_keep, metadata, ranges):
13+
full_dataset = chunk.merge(reference_copy, how='cross', suffixes=('_data', '_ref'))
2214

23-
Returns:
24-
float:
25-
Returns dcr value between two given values.
26-
"""
27-
if pd.isna(synthetic_value) and pd.isna(real_value):
28-
return 0.0
29-
elif pd.isna(synthetic_value) or pd.isna(real_value):
30-
return 1.0
31-
32-
if sdtype == 'numerical' or sdtype == 'datetime':
33-
if col_range is None:
34-
raise ValueError(
35-
'No col_range was provided. The col_range is required '
36-
'for numerical and datetime sdtype DCR calculation.'
15+
for col_name in cols_to_keep:
16+
sdtype = metadata['columns'][col_name]['sdtype']
17+
ref_column = full_dataset[col_name + '_ref']
18+
data_column = full_dataset[col_name + '_data']
19+
diff_col_name = col_name + '_diff'
20+
if sdtype in ['numerical', 'datetime']:
21+
diff = (ref_column - data_column).abs()
22+
if pd.api.types.is_timedelta64_dtype(diff):
23+
diff = diff.dt.total_seconds()
24+
25+
full_dataset[col_name + '_diff'] = np.where(
26+
ranges[col_name] == 0,
27+
(diff > 0).astype(int),
28+
np.minimum(diff / ranges[col_name], 1.0),
3729
)
3830

39-
difference = abs(synthetic_value - real_value)
40-
if isinstance(difference, pd.Timedelta):
41-
difference = difference.total_seconds()
31+
xor_condition = (ref_column.isna() & ~data_column.isna()) | (
32+
~ref_column.isna() & data_column.isna()
33+
)
4234

43-
distance = 0.0 if synthetic_value == real_value else 1.0
44-
if col_range != 0:
45-
distance = difference / col_range
35+
full_dataset.loc[xor_condition, diff_col_name] = 1
4636

47-
return min(distance, 1.0)
37+
both_nan_condition = ref_column.isna() & data_column.isna()
4838

49-
if synthetic_value == real_value:
50-
return 0.0
51-
else:
52-
return 1.0
39+
full_dataset.loc[both_nan_condition, diff_col_name] = 0
5340

41+
elif sdtype in ['categorical', 'boolean']:
42+
equals_cat = (ref_column == data_column) | (ref_column.isna() & data_column.isna())
43+
full_dataset[diff_col_name] = (~equals_cat).astype(int)
5444

55-
def _calculate_dcr_between_rows(synthetic_row, comparison_row, column_ranges, metadata):
56-
"""Calculate the Distance to Closest Record between two rows.
45+
full_dataset.drop(columns=[col_name + '_ref', col_name + '_data'], inplace=True)
5746

58-
Arguments:
59-
synthetic_row (pandas.Series):
60-
The synthetic row that we are calculating DCR value for.
61-
comparison_row (pandas.Series):
62-
The data value that we are referencing for measuring DCR.
63-
column_ranges (dict):
64-
A dictionary that defines the range for each numerical column.
65-
metadata (dict):
66-
The metadata dict.
67-
68-
Returns:
69-
float:
70-
Returns DCR value (the average value of DCR values we computed across the row).
71-
"""
72-
dcr_values = synthetic_row.index.to_series().apply(
73-
lambda synthetic_column_name: _calculate_dcr_value(
74-
synthetic_row[synthetic_column_name],
75-
comparison_row[synthetic_column_name],
76-
metadata['columns'][synthetic_column_name]['sdtype'],
77-
column_ranges.get(synthetic_column_name),
78-
)
47+
full_dataset['diff'] = full_dataset.iloc[:, 2:].sum(axis=1) / len(cols_to_keep)
48+
chunk_result = (
49+
full_dataset[['index_data', 'diff']].groupby('index_data').min().reset_index(drop=True)
7950
)
80-
81-
return dcr_values.mean()
51+
return chunk_result['diff']
8252

8353

84-
def _calculate_dcr_between_row_and_data(synthetic_row, real_data, column_ranges, metadata):
85-
"""Calculate the DCR between a single row in the synthetic data and another dataset.
54+
def calculate_dcr(dataset, reference_dataset, metadata):
55+
"""Calculate the Distance to Closest Record for all rows in the dataset.
8656
8757
Arguments:
88-
synthetic_row (pandas.Series):
89-
The synthetic row that we are calculating DCR against an entire dataset.
90-
real_data (pandas.Dataframe):
91-
The dataset that acts as the reference for DCR calculations.
92-
column_ranges (dict):
93-
A dictionary that defines the range for each numerical column.
58+
dataset (pandas.DataFrame):
59+
The dataset for which we want to compute the DCR values.
60+
reference_dataset (pandas.DataFrame):
61+
The reference dataset that is used for the distance computations.
9462
metadata (dict):
9563
The metadata dict.
9664
9765
Returns:
98-
float:
99-
Returns the minimum distance to closest record computed between the
100-
synthetic row and the reference dataset.
66+
pandas.Series:
67+
Returns a Series that shows the DCR value for every row of the dataset.
10168
"""
102-
synthetic_distance_to_all_real = real_data.apply(
103-
lambda real_row: _calculate_dcr_between_rows(
104-
synthetic_row, real_row, column_ranges, metadata
105-
),
106-
axis=1,
107-
)
108-
return synthetic_distance_to_all_real.min()
69+
dataset_copy = _process_data_with_metadata(dataset.copy(), metadata, True)
70+
reference_copy = _process_data_with_metadata(reference_dataset.copy(), metadata, True)
10971

72+
common_cols = set(dataset_copy.columns) & set(reference_copy.columns)
73+
cols_to_keep = []
74+
ranges = {}
11075

111-
def calculate_dcr(real_data, synthetic_data, metadata):
112-
"""Calculate the Distance to Closest Record for all rows in the synthetic data.
76+
for col_name, col_metadata in get_columns_from_metadata(metadata).items():
77+
sdtype = col_metadata['sdtype']
11378

114-
Arguments:
115-
real_data (pandas.Dataframe):
116-
The dataset that acts as the reference for DCR calculations. Ranges are determined from
117-
this dataset.
118-
synthetic_data (pandas.Dataframe):
119-
The synthetic data that we are calculating DCR values for. Every row will be measured
120-
against the comparison data.
121-
metadata (dict):
122-
The metadata dict.
79+
if (
80+
sdtype in ['numerical', 'categorical', 'boolean', 'datetime']
81+
and col_name in common_cols
82+
):
83+
cols_to_keep.append(col_name)
12384

124-
Returns:
125-
pandas.Series:
126-
Returns a Series that shows the DCR value for every row of synthetic data.
127-
"""
128-
column_ranges = {}
85+
if sdtype in ['numerical', 'datetime']:
86+
col_range = reference_copy[col_name].max() - reference_copy[col_name].min()
87+
if isinstance(col_range, pd.Timedelta):
88+
col_range = col_range.total_seconds()
12989

130-
real_data_copy = real_data.copy()
131-
synthetic_data_copy = synthetic_data.copy()
132-
real_data_copy = _process_data_with_metadata(real_data_copy, metadata, True)
133-
synthetic_data_copy = _process_data_with_metadata(synthetic_data_copy, metadata, True)
90+
ranges[col_name] = col_range
13491

135-
overlapping_columns = set(real_data_copy.columns) & set(synthetic_data_copy.columns)
136-
if not overlapping_columns:
92+
if not cols_to_keep:
13793
raise ValueError('There are no overlapping statistical columns to measure.')
13894

139-
for col_name, column in get_columns_from_metadata(metadata).items():
140-
sdtype = column['sdtype']
141-
col_range = None
142-
if sdtype == 'numerical' or sdtype == 'datetime':
143-
col_range = real_data_copy[col_name].max() - real_data_copy[col_name].min()
144-
if isinstance(col_range, pd.Timedelta):
145-
col_range = col_range.total_seconds()
146-
147-
column_ranges[col_name] = col_range
148-
149-
dcr_dist_df = synthetic_data_copy.apply(
150-
lambda synth_row: _calculate_dcr_between_row_and_data(
151-
synth_row, real_data_copy, column_ranges, metadata
152-
),
153-
axis=1,
154-
)
95+
dataset_copy = dataset_copy[cols_to_keep]
96+
dataset_copy['index'] = range(len(dataset_copy))
97+
98+
reference_copy = reference_copy[cols_to_keep]
99+
reference_copy['index'] = range(len(reference_copy))
100+
results = []
101+
102+
for chunk_start in range(0, len(dataset_copy), CHUNK_SIZE):
103+
chunk = dataset_copy.iloc[chunk_start : chunk_start + CHUNK_SIZE].copy()
104+
chunk_result = _process_dcr_chunk(
105+
chunk=chunk,
106+
reference_copy=reference_copy,
107+
cols_to_keep=cols_to_keep,
108+
metadata=metadata,
109+
ranges=ranges,
110+
)
111+
results.append(chunk_result)
112+
113+
result = pd.concat(results, ignore_index=True)
114+
result.name = None
155115

156-
return dcr_dist_df
116+
return result

tests/integration/single_table/privacy/test_dcr_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_calculate_dcr():
2020
metadata = {'columns': {'num_col': {'sdtype': 'numerical'}}}
2121

2222
# Run
23-
result = calculate_dcr(real_data=real_df, synthetic_data=synthetic_df_diff, metadata=metadata)
23+
result = calculate_dcr(reference_dataset=real_df, dataset=synthetic_df_diff, metadata=metadata)
2424

2525
# Assert
2626
expected_result = pd.Series([0.2, 0.0])
@@ -49,7 +49,7 @@ def test_calculate_dcr_with_zero_col_range():
4949
metadata = {'columns': {'num_col': {'sdtype': 'numerical'}, 'date_col': {'sdtype': 'datetime'}}}
5050

5151
# Run
52-
result = calculate_dcr(real_data=real_df, synthetic_data=synthetic_df_diff, metadata=metadata)
52+
result = calculate_dcr(reference_dataset=real_df, dataset=synthetic_df_diff, metadata=metadata)
5353

5454
# Assert
5555
expected_result = pd.Series([1.0, 1.0, 1.0, 0.5, 0.0])

0 commit comments

Comments
 (0)