|
1 | 1 | """Distance to closest record measurement functions."""
|
2 | 2 |
|
| 3 | +import numpy as np |
3 | 4 | import pandas as pd
|
4 | 5 |
|
5 | 6 | from sdmetrics._utils_metadata import _process_data_with_metadata
|
6 | 7 | from sdmetrics.utils import get_columns_from_metadata
|
7 | 8 |
|
| 9 | +CHUNK_SIZE = 1000 |
8 | 10 |
|
9 |
| -def _calculate_dcr_value(synthetic_value, real_value, sdtype, col_range=None): |
10 |
| - """Calculate the Distance to Closest Record between two different values. |
11 | 11 |
|
12 |
| - Arguments: |
13 |
| - synthetic_value (int, float, datetime, boolean, string, or None): |
14 |
| - The synthetic value that we are calculating DCR value for |
15 |
| - real_value (int, float, datetime, boolean, string, or None): |
16 |
| - The data value that we are referencing for measuring DCR. |
17 |
| - sdtype (string): |
18 |
| - The sdtype of the column values. |
19 |
| - col_range (float): |
20 |
| - The range of values for a column used for numerical values to calculate DCR. |
21 |
| - Defaults to None. |
| 12 | +def _process_dcr_chunk(chunk, reference_copy, cols_to_keep, metadata, ranges): |
| 13 | + full_dataset = chunk.merge(reference_copy, how='cross', suffixes=('_data', '_ref')) |
22 | 14 |
|
23 |
| - Returns: |
24 |
| - float: |
25 |
| - Returns dcr value between two given values. |
26 |
| - """ |
27 |
| - if pd.isna(synthetic_value) and pd.isna(real_value): |
28 |
| - return 0.0 |
29 |
| - elif pd.isna(synthetic_value) or pd.isna(real_value): |
30 |
| - return 1.0 |
31 |
| - |
32 |
| - if sdtype == 'numerical' or sdtype == 'datetime': |
33 |
| - if col_range is None: |
34 |
| - raise ValueError( |
35 |
| - 'No col_range was provided. The col_range is required ' |
36 |
| - 'for numerical and datetime sdtype DCR calculation.' |
| 15 | + for col_name in cols_to_keep: |
| 16 | + sdtype = metadata['columns'][col_name]['sdtype'] |
| 17 | + ref_column = full_dataset[col_name + '_ref'] |
| 18 | + data_column = full_dataset[col_name + '_data'] |
| 19 | + diff_col_name = col_name + '_diff' |
| 20 | + if sdtype in ['numerical', 'datetime']: |
| 21 | + diff = (ref_column - data_column).abs() |
| 22 | + if pd.api.types.is_timedelta64_dtype(diff): |
| 23 | + diff = diff.dt.total_seconds() |
| 24 | + |
| 25 | + full_dataset[col_name + '_diff'] = np.where( |
| 26 | + ranges[col_name] == 0, |
| 27 | + (diff > 0).astype(int), |
| 28 | + np.minimum(diff / ranges[col_name], 1.0), |
37 | 29 | )
|
38 | 30 |
|
39 |
| - difference = abs(synthetic_value - real_value) |
40 |
| - if isinstance(difference, pd.Timedelta): |
41 |
| - difference = difference.total_seconds() |
| 31 | + xor_condition = (ref_column.isna() & ~data_column.isna()) | ( |
| 32 | + ~ref_column.isna() & data_column.isna() |
| 33 | + ) |
42 | 34 |
|
43 |
| - distance = 0.0 if synthetic_value == real_value else 1.0 |
44 |
| - if col_range != 0: |
45 |
| - distance = difference / col_range |
| 35 | + full_dataset.loc[xor_condition, diff_col_name] = 1 |
46 | 36 |
|
47 |
| - return min(distance, 1.0) |
| 37 | + both_nan_condition = ref_column.isna() & data_column.isna() |
48 | 38 |
|
49 |
| - if synthetic_value == real_value: |
50 |
| - return 0.0 |
51 |
| - else: |
52 |
| - return 1.0 |
| 39 | + full_dataset.loc[both_nan_condition, diff_col_name] = 0 |
53 | 40 |
|
| 41 | + elif sdtype in ['categorical', 'boolean']: |
| 42 | + equals_cat = (ref_column == data_column) | (ref_column.isna() & data_column.isna()) |
| 43 | + full_dataset[diff_col_name] = (~equals_cat).astype(int) |
54 | 44 |
|
55 |
| -def _calculate_dcr_between_rows(synthetic_row, comparison_row, column_ranges, metadata): |
56 |
| - """Calculate the Distance to Closest Record between two rows. |
| 45 | + full_dataset.drop(columns=[col_name + '_ref', col_name + '_data'], inplace=True) |
57 | 46 |
|
58 |
| - Arguments: |
59 |
| - synthetic_row (pandas.Series): |
60 |
| - The synthetic row that we are calculating DCR value for. |
61 |
| - comparison_row (pandas.Series): |
62 |
| - The data value that we are referencing for measuring DCR. |
63 |
| - column_ranges (dict): |
64 |
| - A dictionary that defines the range for each numerical column. |
65 |
| - metadata (dict): |
66 |
| - The metadata dict. |
67 |
| -
|
68 |
| - Returns: |
69 |
| - float: |
70 |
| - Returns DCR value (the average value of DCR values we computed across the row). |
71 |
| - """ |
72 |
| - dcr_values = synthetic_row.index.to_series().apply( |
73 |
| - lambda synthetic_column_name: _calculate_dcr_value( |
74 |
| - synthetic_row[synthetic_column_name], |
75 |
| - comparison_row[synthetic_column_name], |
76 |
| - metadata['columns'][synthetic_column_name]['sdtype'], |
77 |
| - column_ranges.get(synthetic_column_name), |
78 |
| - ) |
| 47 | + full_dataset['diff'] = full_dataset.iloc[:, 2:].sum(axis=1) / len(cols_to_keep) |
| 48 | + chunk_result = ( |
| 49 | + full_dataset[['index_data', 'diff']].groupby('index_data').min().reset_index(drop=True) |
79 | 50 | )
|
80 |
| - |
81 |
| - return dcr_values.mean() |
| 51 | + return chunk_result['diff'] |
82 | 52 |
|
83 | 53 |
|
84 |
| -def _calculate_dcr_between_row_and_data(synthetic_row, real_data, column_ranges, metadata): |
85 |
| - """Calculate the DCR between a single row in the synthetic data and another dataset. |
| 54 | +def calculate_dcr(dataset, reference_dataset, metadata): |
| 55 | + """Calculate the Distance to Closest Record for all rows in the synthetic data. |
86 | 56 |
|
87 | 57 | Arguments:
|
88 |
| - synthetic_row (pandas.Series): |
89 |
| - The synthetic row that we are calculating DCR against an entire dataset. |
90 |
| - real_data (pandas.Dataframe): |
91 |
| - The dataset that acts as the reference for DCR calculations. |
92 |
| - column_ranges (dict): |
93 |
| - A dictionary that defines the range for each numerical column. |
| 58 | + dataset (pandas.Dataframe): |
| 59 | + The dataset for which we want to compute the DCR values |
| 60 | + reference_dataset (pandas.Dataframe): |
| 61 | + The reference dataset that is used for the distance computations |
94 | 62 | metadata (dict):
|
95 | 63 | The metadata dict.
|
96 | 64 |
|
97 | 65 | Returns:
|
98 |
| - float: |
99 |
| - Returns the minimum distance to closest record computed between the |
100 |
| - synthetic row and the reference dataset. |
| 66 | + pandas.Series: |
| 67 | + Returns a Series that shows the DCR value for every row of dataset |
101 | 68 | """
|
102 |
| - synthetic_distance_to_all_real = real_data.apply( |
103 |
| - lambda real_row: _calculate_dcr_between_rows( |
104 |
| - synthetic_row, real_row, column_ranges, metadata |
105 |
| - ), |
106 |
| - axis=1, |
107 |
| - ) |
108 |
| - return synthetic_distance_to_all_real.min() |
| 69 | + dataset_copy = _process_data_with_metadata(dataset.copy(), metadata, True) |
| 70 | + reference_copy = _process_data_with_metadata(reference_dataset.copy(), metadata, True) |
109 | 71 |
|
| 72 | + common_cols = set(dataset_copy.columns) & set(reference_copy.columns) |
| 73 | + cols_to_keep = [] |
| 74 | + ranges = {} |
110 | 75 |
|
111 |
| -def calculate_dcr(real_data, synthetic_data, metadata): |
112 |
| - """Calculate the Distance to Closest Record for all rows in the synthetic data. |
| 76 | + for col_name, col_metadata in get_columns_from_metadata(metadata).items(): |
| 77 | + sdtype = col_metadata['sdtype'] |
113 | 78 |
|
114 |
| - Arguments: |
115 |
| - real_data (pandas.Dataframe): |
116 |
| - The dataset that acts as the reference for DCR calculations. Ranges are determined from |
117 |
| - this dataset. |
118 |
| - synthetic_data (pandas.Dataframe): |
119 |
| - The synthetic data that we are calculating DCR values for. Every row will be measured |
120 |
| - against the comparison data. |
121 |
| - metadata (dict): |
122 |
| - The metadata dict. |
| 79 | + if ( |
| 80 | + sdtype in ['numerical', 'categorical', 'boolean', 'datetime'] |
| 81 | + and col_name in common_cols |
| 82 | + ): |
| 83 | + cols_to_keep.append(col_name) |
123 | 84 |
|
124 |
| - Returns: |
125 |
| - pandas.Series: |
126 |
| - Returns a Series that shows the DCR value for every row of synthetic data. |
127 |
| - """ |
128 |
| - column_ranges = {} |
| 85 | + if sdtype in ['numerical', 'datetime']: |
| 86 | + col_range = reference_copy[col_name].max() - reference_copy[col_name].min() |
| 87 | + if isinstance(col_range, pd.Timedelta): |
| 88 | + col_range = col_range.total_seconds() |
129 | 89 |
|
130 |
| - real_data_copy = real_data.copy() |
131 |
| - synthetic_data_copy = synthetic_data.copy() |
132 |
| - real_data_copy = _process_data_with_metadata(real_data_copy, metadata, True) |
133 |
| - synthetic_data_copy = _process_data_with_metadata(synthetic_data_copy, metadata, True) |
| 90 | + ranges[col_name] = col_range |
134 | 91 |
|
135 |
| - overlapping_columns = set(real_data_copy.columns) & set(synthetic_data_copy.columns) |
136 |
| - if not overlapping_columns: |
| 92 | + if not cols_to_keep: |
137 | 93 | raise ValueError('There are no overlapping statistical columns to measure.')
|
138 | 94 |
|
139 |
| - for col_name, column in get_columns_from_metadata(metadata).items(): |
140 |
| - sdtype = column['sdtype'] |
141 |
| - col_range = None |
142 |
| - if sdtype == 'numerical' or sdtype == 'datetime': |
143 |
| - col_range = real_data_copy[col_name].max() - real_data_copy[col_name].min() |
144 |
| - if isinstance(col_range, pd.Timedelta): |
145 |
| - col_range = col_range.total_seconds() |
146 |
| - |
147 |
| - column_ranges[col_name] = col_range |
148 |
| - |
149 |
| - dcr_dist_df = synthetic_data_copy.apply( |
150 |
| - lambda synth_row: _calculate_dcr_between_row_and_data( |
151 |
| - synth_row, real_data_copy, column_ranges, metadata |
152 |
| - ), |
153 |
| - axis=1, |
154 |
| - ) |
| 95 | + dataset_copy = dataset_copy[cols_to_keep] |
| 96 | + dataset_copy['index'] = range(len(dataset_copy)) |
| 97 | + |
| 98 | + reference_copy = reference_copy[cols_to_keep] |
| 99 | + reference_copy['index'] = range(len(reference_copy)) |
| 100 | + results = [] |
| 101 | + |
| 102 | + for chunk_start in range(0, len(dataset_copy), CHUNK_SIZE): |
| 103 | + chunk = dataset_copy.iloc[chunk_start : chunk_start + CHUNK_SIZE].copy() |
| 104 | + chunk_result = _process_dcr_chunk( |
| 105 | + chunk=chunk, |
| 106 | + reference_copy=reference_copy, |
| 107 | + cols_to_keep=cols_to_keep, |
| 108 | + metadata=metadata, |
| 109 | + ranges=ranges, |
| 110 | + ) |
| 111 | + results.append(chunk_result) |
| 112 | + |
| 113 | + result = pd.concat(results, ignore_index=True) |
| 114 | + result.name = None |
155 | 115 |
|
156 |
| - return dcr_dist_df |
| 116 | + return result |
0 commit comments