Skip to content

Commit 649047a

Browse files
authored
DCRBaseline and DCROverfitting throw incorrect warnings about missing columns. (#756)
1 parent f88c759 commit 649047a

File tree

6 files changed

+49
-49
lines changed

6 files changed

+49
-49
lines changed

sdmetrics/single_table/privacy/dcr_baseline_protection.py

+13-20
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
import pandas as pd
77

8-
from sdmetrics._utils_metadata import _process_data_with_metadata
98
from sdmetrics.goal import Goal
109
from sdmetrics.single_table.base import SingleTableMetric
1110
from sdmetrics.single_table.privacy.dcr_utils import calculate_dcr
@@ -31,7 +30,6 @@ def _validate_inputs(
3130
cls,
3231
real_data,
3332
synthetic_data,
34-
metadata,
3533
num_rows_subsample,
3634
num_iterations,
3735
):
@@ -46,12 +44,13 @@ def _validate_inputs(
4644
num_rows_subsample = None
4745
num_iterations = 1
4846

49-
real_data_copy = real_data.copy()
50-
synthetic_data_copy = synthetic_data.copy()
51-
real_data_copy = _process_data_with_metadata(real_data_copy, metadata, True)
52-
synthetic_data_copy = _process_data_with_metadata(synthetic_data_copy, metadata, True)
47+
if not (isinstance(real_data, pd.DataFrame) and isinstance(synthetic_data, pd.DataFrame)):
48+
raise TypeError(
49+
f'Both real_data ({type(real_data)}) and synthetic_data ({type(synthetic_data)}) '
50+
'must be of type pandas.DataFrame.'
51+
)
5352

54-
return real_data_copy, synthetic_data_copy, num_rows_subsample, num_iterations
53+
return num_rows_subsample, num_iterations
5554

5655
@classmethod
5756
def compute_breakdown(
@@ -87,38 +86,32 @@ def compute_breakdown(
8786
and the median DCR score between the random data and real data.
8887
Averages of the medians are returned in the case of multiple iterations.
8988
"""
90-
sanitized_data = cls._validate_inputs(
89+
num_rows_subsample, num_iterations = cls._validate_inputs(
9190
real_data,
9291
synthetic_data,
93-
metadata,
9492
num_rows_subsample,
9593
num_iterations,
9694
)
9795

98-
sanitized_real_data = sanitized_data[0]
99-
sanitized_synthetic_data = sanitized_data[1]
100-
num_rows_subsample = sanitized_data[2]
101-
num_iterations = sanitized_data[3]
102-
103-
size_of_random_data = len(sanitized_synthetic_data)
104-
random_data = cls._generate_random_data(sanitized_real_data, size_of_random_data)
96+
size_of_random_data = len(synthetic_data)
97+
random_data = cls._generate_random_data(real_data, size_of_random_data)
10598

10699
sum_synthetic_median = 0
107100
sum_random_median = 0
108101
sum_score = 0
109102

110103
for _ in range(num_iterations):
111-
synthetic_sample = sanitized_synthetic_data
104+
synthetic_sample = synthetic_data
112105
random_sample = random_data
113106
if num_rows_subsample is not None:
114-
synthetic_sample = sanitized_synthetic_data.sample(n=num_rows_subsample)
107+
synthetic_sample = synthetic_data.sample(n=num_rows_subsample)
115108
random_sample = random_data.sample(n=num_rows_subsample)
116109

117110
dcr_real = calculate_dcr(
118-
reference_dataset=sanitized_real_data, dataset=synthetic_sample, metadata=metadata
111+
reference_dataset=real_data, dataset=synthetic_sample, metadata=metadata
119112
)
120113
dcr_random = calculate_dcr(
121-
reference_dataset=sanitized_real_data, dataset=random_sample, metadata=metadata
114+
reference_dataset=real_data, dataset=random_sample, metadata=metadata
122115
)
123116
synthetic_data_median = dcr_real.median()
124117
random_data_median = dcr_random.median()

sdmetrics/single_table/privacy/dcr_overfitting_protection.py

+18-28
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import warnings
44

55
import numpy as np
6+
import pandas as pd
67

7-
from sdmetrics._utils_metadata import _process_data_with_metadata
88
from sdmetrics.goal import Goal
99
from sdmetrics.single_table.base import SingleTableMetric
1010
from sdmetrics.single_table.privacy.dcr_utils import calculate_dcr
@@ -29,7 +29,6 @@ def _validate_inputs(
2929
real_training_data,
3030
synthetic_data,
3131
real_validation_data,
32-
metadata,
3332
num_rows_subsample,
3433
num_iterations,
3534
):
@@ -44,27 +43,25 @@ def _validate_inputs(
4443
num_rows_subsample = None
4544
num_iterations = 1
4645

46+
if not (
47+
isinstance(real_training_data, pd.DataFrame)
48+
and isinstance(synthetic_data, pd.DataFrame)
49+
and isinstance(real_validation_data, pd.DataFrame)
50+
):
51+
raise TypeError(
52+
f'All of real_training_data ({type(real_training_data)}), synthetic_data '
53+
f'({type(synthetic_data)}), and real_validation_data ({type(real_validation_data)}) '
54+
'must be of type pandas.DataFrame.'
55+
)
56+
4757
if len(real_training_data) * 0.5 > len(real_validation_data):
4858
warnings.warn(
4959
f'Your real_validation_data contains {len(real_validation_data)} rows while your '
5060
f'real_training_data contains {len(real_training_data)} rows. For most accurate '
5161
'results, we recommend that the validation data at least half the size of the training data.'
5262
)
5363

54-
real_data_copy = real_training_data.copy()
55-
synthetic_data_copy = synthetic_data.copy()
56-
real_validation_copy = real_validation_data.copy()
57-
real_data_copy = _process_data_with_metadata(real_data_copy, metadata, True)
58-
synthetic_data_copy = _process_data_with_metadata(synthetic_data_copy, metadata, True)
59-
real_validation_copy = _process_data_with_metadata(real_validation_copy, metadata, True)
60-
61-
return (
62-
real_data_copy,
63-
synthetic_data_copy,
64-
real_validation_copy,
65-
num_rows_subsample,
66-
num_iterations,
67-
)
64+
return num_rows_subsample, num_iterations
6865

6966
@classmethod
7067
def compute_breakdown(
@@ -104,34 +101,27 @@ def compute_breakdown(
104101
closer to the real dataset. Averages of the medians are returned in the case of
105102
multiple iterations.
106103
"""
107-
sanitized_data = cls._validate_inputs(
104+
num_rows_subsample, num_iterations = cls._validate_inputs(
108105
real_training_data,
109106
synthetic_data,
110107
real_validation_data,
111-
metadata,
112108
num_rows_subsample,
113109
num_iterations,
114110
)
115111

116-
training_data = sanitized_data[0]
117-
sanitized_synthetic_data = sanitized_data[1]
118-
validation_data = sanitized_data[2]
119-
num_rows_subsample = sanitized_data[3]
120-
num_iterations = sanitized_data[4]
121-
122112
sum_of_scores = 0
123113
sum_percent_close_to_real = 0
124114
sum_percent_close_to_random = 0
125115
for _ in range(num_iterations):
126-
synthetic_sample = sanitized_synthetic_data
116+
synthetic_sample = synthetic_data
127117
if num_rows_subsample is not None:
128-
synthetic_sample = sanitized_synthetic_data.sample(n=num_rows_subsample)
118+
synthetic_sample = synthetic_data.sample(n=num_rows_subsample)
129119

130120
dcr_real = calculate_dcr(
131-
reference_dataset=training_data, dataset=synthetic_sample, metadata=metadata
121+
reference_dataset=real_training_data, dataset=synthetic_sample, metadata=metadata
132122
)
133123
dcr_holdout = calculate_dcr(
134-
reference_dataset=validation_data, dataset=synthetic_sample, metadata=metadata
124+
reference_dataset=real_validation_data, dataset=synthetic_sample, metadata=metadata
135125
)
136126

137127
num_rows_closer_to_real = np.where(dcr_real < dcr_holdout, 1.0, 0.0).sum()

tests/integration/single_table/privacy/test_dcr_baseline_protection.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313

1414
class TestDCRBaselineProtection:
15+
@pytest.mark.filterwarnings('error')
1516
def test_end_to_end_with_demo(self):
1617
"""Test end to end for DCRBaslineProtection metric against the demo dataset.
1718

tests/integration/single_table/privacy/test_dcr_overfitting_protection.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212

1313
class TestDCROverfittingProtection:
14+
@pytest.mark.filterwarnings('error')
1415
def test_end_to_end_with_demo(self):
1516
"""Test end to end for DCROverfittingProtection metric against the demo dataset.
1617
@@ -21,7 +22,7 @@ def test_end_to_end_with_demo(self):
2122
"""
2223
# Setup
2324
real_data, synthetic_data, metadata = load_single_table_demo()
24-
train_df, holdout_df = train_test_split(real_data, test_size=0.2)
25+
train_df, holdout_df = train_test_split(real_data, test_size=0.5)
2526

2627
# Run
2728
num_rows_subsample = 50

tests/unit/single_table/privacy/test_dcr_baseline_protection.py

+7
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ def test__validate_inputs(self, test_data):
5353
with pytest.raises(ValueError, match=missing_metric):
5454
DCRBaselineProtection.compute_breakdown(no_dcr_data, no_dcr_data, no_dcr_metadata)
5555

56+
no_df_msg = re.escape(
57+
f'Both real_data ({type(None)}) and synthetic_data ({type({})}) '
58+
'must be of type pandas.DataFrame.'
59+
)
60+
with pytest.raises(TypeError, match=no_df_msg):
61+
DCRBaselineProtection.compute_breakdown(None, {}, metadata)
62+
5663
@patch(
5764
'sdmetrics.single_table.privacy.dcr_baseline_protection.DCRBaselineProtection._generate_random_data'
5865
)

tests/unit/single_table/privacy/test_dcr_overfitting_protection.py

+8
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ def test__validate_inputs(self, test_data):
7171
train_data, synthetic_data, small_holdout_data, metadata
7272
)
7373

74+
no_df_msg = re.escape(
75+
f'All of real_training_data ({type(None)}), synthetic_data '
76+
f'({type({})}), and real_validation_data ({type({})}) '
77+
'must be of type pandas.DataFrame.'
78+
)
79+
with pytest.raises(TypeError, match=no_df_msg):
80+
DCROverfittingProtection.compute_breakdown(None, {}, {}, metadata)
81+
7482
@patch('numpy.where')
7583
@patch('sdmetrics.single_table.privacy.dcr_overfitting_protection.calculate_dcr')
7684
def test_compute_breakdown(self, mock_calculate_dcr, mock_numpy_where, test_data):

0 commit comments

Comments
 (0)