Skip to content

Commit 35ef9ce

Browse files
committed
Merge branch 'main' into dcr_feature_branch
2 parents 2941fcb + 2bd9681 commit 35ef9ce

File tree

13 files changed

+115
-171
lines changed

13 files changed

+115
-171
lines changed

sdmetrics/_utils_metadata.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import pandas as pd
44

5+
from sdmetrics.utils import is_datetime
6+
57
MODELABLE_SDTYPES = ('numerical', 'datetime', 'categorical', 'boolean')
68

79

@@ -69,22 +71,34 @@ def wrapper(data, metadata):
6971
return wrapper
7072

7173

74+
def _convert_datetime_column(column_name, column_data, column_metadata):
    """Convert a single column to pandas datetime using the format from its metadata.

    Args:
        column_name (str):
            Name of the column. Only used to build the error messages.
        column_data (pandas.Series):
            The column values to convert (callers pass a single DataFrame column).
        column_metadata (dict):
            Metadata of the column. Its ``'datetime_format'`` entry is read when the
            data is not already of a datetime type.

    Returns:
        pandas.Series:
            ``column_data`` itself when ``is_datetime`` reports it is already datetime,
            otherwise the result of ``pd.to_datetime`` with the metadata format.

    Raises:
        ValueError:
            If the data is not datetime and the metadata has no ``'datetime_format'``,
            or if the conversion with that format fails.
    """
    if is_datetime(column_data):
        return column_data

    datetime_format = column_metadata.get('datetime_format')
    if datetime_format is None:
        raise ValueError(
            f"Datetime column '{column_name}' does not have a specified 'datetime_format'. "
            'Please add a the required datetime_format to the metadata or convert this column '
            "to 'pd.datetime' to bypass this requirement."
        )

    try:
        # Parse exactly once and return that result; converting a second time after a
        # successful "validation" pass would double the parsing cost for every column.
        return pd.to_datetime(column_data, format=datetime_format)
    except Exception as e:
        # Keep the broad catch so every conversion failure surfaces as a ValueError,
        # but chain the original exception so its traceback is preserved.
        raise ValueError(f"Error converting column '{column_name}' to timestamp: {e}") from e
92+
93+
7294
@handle_single_and_multi_table
7395
def _convert_datetime_columns(data, metadata):
7496
"""Convert datetime columns to datetime type."""
7597
for column in metadata['columns']:
7698
if metadata['columns'][column]['sdtype'] == 'datetime':
77-
is_datetime = pd.api.types.is_datetime64_any_dtype(data[column])
78-
if not is_datetime:
79-
datetime_format = metadata['columns'][column].get('datetime_format')
80-
if datetime_format:
81-
data[column] = pd.to_datetime(data[column], format=datetime_format)
82-
else:
83-
raise ValueError(
84-
f"Datetime column '{column}' does not have a specified 'datetime_format'. "
85-
'Please add a the required datetime_format to the metadata or convert this column '
86-
"to 'pd.datetime' to bypass this requirement."
87-
)
99+
data[column] = _convert_datetime_column(
100+
column, data[column], metadata['columns'][column]
101+
)
88102

89103
return data
90104

sdmetrics/reports/base_report.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
import pandas as pd
1414
import tqdm
1515

16-
from sdmetrics._utils_metadata import _validate_metadata
17-
from sdmetrics.reports.utils import convert_datetime_columns
16+
from sdmetrics._utils_metadata import _convert_datetime_column, _validate_metadata
1817
from sdmetrics.visualization import set_plotly_config
1918

2019

@@ -101,8 +100,8 @@ def convert_datetimes(real_data, synthetic_data, metadata):
101100
real_col = real_data[column]
102101
synth_col = synthetic_data[column]
103102
try:
104-
converted_cols = convert_datetime_columns(real_col, synth_col, col_meta)
105-
real_data[column], synthetic_data[column] = converted_cols
103+
real_data[column] = _convert_datetime_column(column, real_col, col_meta)
104+
synthetic_data[column] = _convert_datetime_column(column, synth_col, col_meta)
106105
except Exception:
107106
continue
108107

sdmetrics/reports/single_table/_properties/column_pair_trends.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from plotly import graph_objects as go
66
from plotly.subplots import make_subplots
77

8+
from sdmetrics._utils_metadata import _convert_datetime_column
89
from sdmetrics.column_pairs.statistical import ContingencySimilarity, CorrelationSimilarity
910
from sdmetrics.reports.single_table._properties import BaseSingleTableProperty
1011
from sdmetrics.reports.utils import PlotConfig
11-
from sdmetrics.utils import is_datetime
1212

1313
DEFAULT_NUM_ROWS_SUBSAMPLE = 50000
1414

@@ -51,13 +51,9 @@ def _convert_datetime_columns_to_numeric(self, data, metadata):
5151
col_sdtype = column_meta['sdtype']
5252
try:
5353
if col_sdtype == 'datetime':
54-
if not is_datetime(data[column_name]):
55-
datetime_format = column_meta.get(
56-
'datetime_format', column_meta.get('format')
57-
)
58-
data[column_name] = pd.to_datetime(
59-
data[column_name], format=datetime_format
60-
)
54+
data[column_name] = _convert_datetime_column(
55+
column_name, data[column_name], column_meta
56+
)
6157
nan_mask = pd.isna(data[column_name])
6258
data[column_name] = pd.to_numeric(data[column_name])
6359
if nan_mask.any():

sdmetrics/reports/utils.py

+3-51
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@
66

77
import numpy as np
88
import pandas as pd
9-
from pandas.core.tools.datetimes import _guess_datetime_format_for_array
109

10+
from sdmetrics._utils_metadata import _convert_datetime_column
1111
from sdmetrics.utils import (
1212
discretize_column,
1313
get_alternate_keys,
1414
get_columns_from_metadata,
1515
get_type_from_column_meta,
16-
is_datetime,
1716
)
1817

1918
CONTINUOUS_SDTYPES = ['numerical', 'datetime']
@@ -35,51 +34,6 @@ class PlotConfig:
3534
FONT_SIZE = 18
3635

3736

38-
def convert_to_datetime(column_data, datetime_format=None):
39-
"""Convert a column data to pandas datetime.
40-
41-
Args:
42-
column_data (pandas.Series):
43-
The column data
44-
format (str):
45-
Optional string format of datetime. If ``None``, will attempt to infer the datetime
46-
format from the column data. Defaults to ``None``.
47-
48-
Returns:
49-
pandas.Series:
50-
The converted column data.
51-
"""
52-
if is_datetime(column_data):
53-
return column_data
54-
55-
if datetime_format is None:
56-
datetime_format = _guess_datetime_format_for_array(column_data.astype(str).to_numpy())
57-
58-
return pd.to_datetime(column_data, format=datetime_format)
59-
60-
61-
def convert_datetime_columns(real_column, synthetic_column, col_metadata):
62-
"""Convert a real and a synthetic column to pandas datetime.
63-
64-
Args:
65-
real_data (pandas.Series):
66-
The real column data
67-
synthetic_column (pandas.Series):
68-
The synthetic column data
69-
col_metadata:
70-
The metadata associated with the column
71-
72-
Returns:
73-
(pandas.Series, pandas.Series):
74-
The converted real and synthetic column data.
75-
"""
76-
datetime_format = col_metadata.get('format') or col_metadata.get('datetime_format')
77-
return (
78-
convert_to_datetime(real_column, datetime_format),
79-
convert_to_datetime(synthetic_column, datetime_format),
80-
)
81-
82-
8337
def discretize_table_data(real_data, synthetic_data, metadata):
8438
"""Create a copy of the real and synthetic data with discretized data.
8539
@@ -109,10 +63,8 @@ def discretize_table_data(real_data, synthetic_data, metadata):
10963
real_col = real_data[column_name]
11064
synthetic_col = synthetic_data[column_name]
11165
if sdtype == 'datetime':
112-
datetime_format = column_meta.get('format') or column_meta.get('datetime_format')
113-
if real_col.dtype == 'O' and datetime_format:
114-
real_col = pd.to_datetime(real_col, format=datetime_format)
115-
synthetic_col = pd.to_datetime(synthetic_col, format=datetime_format)
66+
real_col = _convert_datetime_column(column_name, real_col, column_meta)
67+
synthetic_col = _convert_datetime_column(column_name, synthetic_col, column_meta)
11668

11769
real_col = pd.to_numeric(real_col)
11870
synthetic_col = pd.to_numeric(synthetic_col)

sdmetrics/single_table/base.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
import copy
44
from operator import attrgetter
55

6-
import pandas as pd
7-
8-
from sdmetrics._utils_metadata import _validate_single_table_metadata
6+
from sdmetrics._utils_metadata import _convert_datetime_column, _validate_single_table_metadata
97
from sdmetrics.base import BaseMetric
108
from sdmetrics.errors import IncomputableMetricError
119
from sdmetrics.utils import get_alternate_keys, get_columns_from_metadata, get_type_from_column_meta
@@ -119,14 +117,12 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None):
119117
field_type = get_type_from_column_meta(field_meta)
120118
if field not in real_data.columns:
121119
raise ValueError(f'Field {field} not found in data')
122-
if (
123-
field_type == 'datetime'
124-
and 'datetime_format' in field_meta
125-
and real_data[field].dtype == 'O'
126-
):
127-
dt_format = field_meta['datetime_format']
128-
real_data[field] = pd.to_datetime(real_data[field], format=dt_format)
129-
synthetic_data[field] = pd.to_datetime(synthetic_data[field], format=dt_format)
120+
121+
if field_type == 'datetime':
122+
real_data[field] = _convert_datetime_column(field, real_data[field], field_meta)
123+
synthetic_data[field] = _convert_datetime_column(
124+
field, synthetic_data[field], field_meta
125+
)
130126

131127
return real_data, synthetic_data, metadata
132128

sdmetrics/single_table/new_row_synthesis.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pandas as pd
66

7+
from sdmetrics._utils_metadata import _convert_datetime_column
78
from sdmetrics.errors import IncomputableMetricError
89
from sdmetrics.goal import Goal
910
from sdmetrics.single_table.base import SingleTableMetric
@@ -83,9 +84,10 @@ def compute_breakdown(
8384

8485
for field, field_meta in get_columns_from_metadata(metadata).items():
8586
if get_type_from_column_meta(field_meta) == 'datetime':
86-
if len(real_data[field]) > 0 and isinstance(real_data[field][0], str):
87-
real_data[field] = pd.to_datetime(real_data[field])
88-
synthetic_data[field] = pd.to_datetime(synthetic_data[field])
87+
real_data[field] = _convert_datetime_column(field, real_data[field], field_meta)
88+
synthetic_data[field] = _convert_datetime_column(
89+
field, synthetic_data[field], field_meta
90+
)
8991

9092
real_data[field] = pd.to_numeric(real_data[field])
9193
synthetic_data[field] = pd.to_numeric(synthetic_data[field])

sdmetrics/timeseries/base.py

+9-15
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22

33
from operator import attrgetter
44

5-
import pandas as pd
6-
7-
from sdmetrics._utils_metadata import _validate_metadata_dict
5+
from sdmetrics._utils_metadata import _convert_datetime_column, _validate_metadata_dict
86
from sdmetrics.base import BaseMetric
97
from sdmetrics.utils import get_columns_from_metadata
108

@@ -62,18 +60,14 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None, sequence_key
6260
if field not in real_data.columns:
6361
raise ValueError(f'Field {field} not found in data')
6462

65-
for column, kwargs in metadata['columns'].items():
66-
if kwargs['sdtype'] == 'datetime':
67-
datetime_format = kwargs.get('datetime_format')
68-
try:
69-
real_data[column] = pd.to_datetime(
70-
real_data[column], format=datetime_format
71-
)
72-
synthetic_data[column] = pd.to_datetime(
73-
synthetic_data[column], format=datetime_format
74-
)
75-
except ValueError:
76-
raise ValueError(f"Column '{column}' is not a valid datetime")
63+
for column, col_metadata in metadata['columns'].items():
64+
if col_metadata['sdtype'] == 'datetime':
65+
real_data[column] = _convert_datetime_column(
66+
column, real_data[column], col_metadata
67+
)
68+
synthetic_data[column] = _convert_datetime_column(
69+
column, synthetic_data[column], col_metadata
70+
)
7771

7872
else:
7973
dtype_kinds = real_data.dtypes.apply(attrgetter('kind'))

tests/integration/timeseries/test_timeseries.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,22 @@ def test_compute_lstmdetection_multiple_categorical_columns():
6363
def test_compute_lstmdetection_mismatching_datetime_columns():
6464
"""Test LSTMDetection metric with mismatching datetime columns.
6565
66-
Test it when the real data has a date column and the synthetic data has a string column.
66+
Test it when the real data has a datetime column and the synthetic data has a string column.
6767
"""
6868
# Setup
6969
df1 = pd.DataFrame({
7070
's_key': [1, 2, 3, 4, 5],
7171
'visits': pd.to_datetime(['1/1/2019', '1/2/2019', '1/3/2019', '1/4/2019', '1/5/2019']),
7272
})
73-
df1['visits'] = df1['visits'].dt.date
7473
df2 = pd.DataFrame({
7574
's_key': [1, 2, 3, 4, 5],
7675
'visits': ['1/2/2019', '1/2/2019', '1/3/2019', '1/4/2019', '1/5/2019'],
7776
})
7877
metadata = {
79-
'columns': {'s_key': {'sdtype': 'numerical'}, 'visits': {'sdtype': 'datetime'}},
78+
'columns': {
79+
's_key': {'sdtype': 'numerical'},
80+
'visits': {'sdtype': 'datetime', 'datetime_format': '%m/%d/%Y'},
81+
},
8082
'sequence_key': 's_key',
8183
}
8284

tests/unit/reports/multi_table/test_base_multi_table_report.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,10 @@ def test_convert_datetimes(self):
209209
metadata = {
210210
'tables': {
211211
'table1': {
212-
'columns': {'col1': {'sdtype': 'datetime'}, 'col2': {'sdtype': 'datetime'}},
212+
'columns': {
213+
'col1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
214+
'col2': {'sdtype': 'datetime'},
215+
},
213216
},
214217
},
215218
}

tests/unit/reports/test_base_report.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,10 @@ def test_convert_datetimes(self):
174174
real_data = pd.DataFrame({'col1': ['2020-01-02', '2021-01-02'], 'col2': ['a', 'b']})
175175
synthetic_data = pd.DataFrame({'col1': ['2022-01-03', '2023-04-05'], 'col2': ['b', 'a']})
176176
metadata = {
177-
'columns': {'col1': {'sdtype': 'datetime'}, 'col2': {'sdtype': 'datetime'}},
177+
'columns': {
178+
'col1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
179+
'col2': {'sdtype': 'datetime'},
180+
},
178181
}
179182

180183
# Run

0 commit comments

Comments
 (0)