3
3
import warnings
4
4
5
5
import numpy as np
6
+ import pandas as pd
6
7
7
- from sdmetrics ._utils_metadata import _process_data_with_metadata
8
8
from sdmetrics .goal import Goal
9
9
from sdmetrics .single_table .base import SingleTableMetric
10
10
from sdmetrics .single_table .privacy .dcr_utils import calculate_dcr
@@ -29,7 +29,6 @@ def _validate_inputs(
29
29
real_training_data ,
30
30
synthetic_data ,
31
31
real_validation_data ,
32
- metadata ,
33
32
num_rows_subsample ,
34
33
num_iterations ,
35
34
):
@@ -44,27 +43,25 @@ def _validate_inputs(
44
43
num_rows_subsample = None
45
44
num_iterations = 1
46
45
46
+ if not (
47
+ isinstance (real_training_data , pd .DataFrame )
48
+ and isinstance (synthetic_data , pd .DataFrame )
49
+ and isinstance (real_validation_data , pd .DataFrame )
50
+ ):
51
+ raise TypeError (
52
+ f'All of real_training_data ({ type (real_training_data )} ), synthetic_data '
53
+ f'({ type (synthetic_data )} ), and real_validation_data ({ type (real_validation_data )} ) '
54
+ 'must be of type pandas.DataFrame.'
55
+ )
56
+
47
57
if len (real_training_data ) * 0.5 > len (real_validation_data ):
48
58
warnings .warn (
49
59
f'Your real_validation_data contains { len (real_validation_data )} rows while your '
50
60
f'real_training_data contains { len (real_training_data )} rows. For most accurate '
51
61
'results, we recommend that the validation data at least half the size of the training data.'
52
62
)
53
63
54
- real_data_copy = real_training_data .copy ()
55
- synthetic_data_copy = synthetic_data .copy ()
56
- real_validation_copy = real_validation_data .copy ()
57
- real_data_copy = _process_data_with_metadata (real_data_copy , metadata , True )
58
- synthetic_data_copy = _process_data_with_metadata (synthetic_data_copy , metadata , True )
59
- real_validation_copy = _process_data_with_metadata (real_validation_copy , metadata , True )
60
-
61
- return (
62
- real_data_copy ,
63
- synthetic_data_copy ,
64
- real_validation_copy ,
65
- num_rows_subsample ,
66
- num_iterations ,
67
- )
64
+ return num_rows_subsample , num_iterations
68
65
69
66
@classmethod
70
67
def compute_breakdown (
@@ -104,34 +101,27 @@ def compute_breakdown(
104
101
closer to the real dataset. Averages of the medians are returned in the case of
105
102
multiple iterations.
106
103
"""
107
- sanitized_data = cls ._validate_inputs (
104
+ num_rows_subsample , num_iterations = cls ._validate_inputs (
108
105
real_training_data ,
109
106
synthetic_data ,
110
107
real_validation_data ,
111
- metadata ,
112
108
num_rows_subsample ,
113
109
num_iterations ,
114
110
)
115
111
116
- training_data = sanitized_data [0 ]
117
- sanitized_synthetic_data = sanitized_data [1 ]
118
- validation_data = sanitized_data [2 ]
119
- num_rows_subsample = sanitized_data [3 ]
120
- num_iterations = sanitized_data [4 ]
121
-
122
112
sum_of_scores = 0
123
113
sum_percent_close_to_real = 0
124
114
sum_percent_close_to_random = 0
125
115
for _ in range (num_iterations ):
126
- synthetic_sample = sanitized_synthetic_data
116
+ synthetic_sample = synthetic_data
127
117
if num_rows_subsample is not None :
128
- synthetic_sample = sanitized_synthetic_data .sample (n = num_rows_subsample )
118
+ synthetic_sample = synthetic_data .sample (n = num_rows_subsample )
129
119
130
120
dcr_real = calculate_dcr (
131
- reference_dataset = training_data , dataset = synthetic_sample , metadata = metadata
121
+ reference_dataset = real_training_data , dataset = synthetic_sample , metadata = metadata
132
122
)
133
123
dcr_holdout = calculate_dcr (
134
- reference_dataset = validation_data , dataset = synthetic_sample , metadata = metadata
124
+ reference_dataset = real_validation_data , dataset = synthetic_sample , metadata = metadata
135
125
)
136
126
137
127
num_rows_closer_to_real = np .where (dcr_real < dcr_holdout , 1.0 , 0.0 ).sum ()
0 commit comments