Skip to content

Commit 202aa7b

Browse files
committed
Define `_validate_cags` for multi-table synthesizers and add tests
1 parent 21a93ac commit 202aa7b

File tree

3 files changed

+87
-14
lines changed

3 files changed

+87
-14
lines changed

sdv/multi_table/base.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
check_synthesizer_version,
1919
generate_synthesizer_id,
2020
)
21-
from sdv.cag._errors import PatternNotMetError
2221
from sdv.errors import (
2322
ConstraintsNotMetError,
2423
InvalidDataError,
@@ -118,6 +117,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
118117
self._table_synthesizers = {}
119118
self._table_parameters = defaultdict(dict)
120119
self._original_table_columns = {}
120+
self._original_metadata = deepcopy(self.metadata)
121121
if synthesizer_kwargs is not None:
122122
warn_message = (
123123
'The `synthesizer_kwargs` parameter is deprecated as of SDV 1.2.0 and does not '
@@ -323,6 +323,22 @@ def _validate_all_tables(self, data):
323323

324324
return errors
325325

326+
def _validate_cags(self, data):
    """Validate the data against every registered CAG pattern, in order.

    Each pattern is validated against the metadata produced by the patterns
    that precede it: the first pattern sees the starting metadata, and every
    following pattern sees the result of the previous pattern's
    ``get_updated_metadata`` call.

    Args:
        data (dict):
            The multi-table data to validate. ``validate`` passes the same
            object it iterates over by table name, so this is expected to be
            a mapping of table name to that table's ``pandas.DataFrame``.

    Raises:
        Whatever a pattern's ``validate`` (or ``get_updated_metadata``) raises
        is propagated unchanged; nothing is caught or aggregated here.
    """
    # Start from the metadata captured at construction time when it exists.
    # NOTE(review): the fallback to ``self.metadata`` presumably covers
    # instances created before ``_original_metadata`` was introduced - confirm.
    metadata = getattr(self, '_original_metadata', self.metadata)

    # ``patterns`` may be absent altogether; in that case there is nothing to check.
    for pattern in getattr(self, 'patterns', []):
        pattern.validate(data=data, metadata=metadata)
        # Thread the metadata through so the next pattern validates against
        # the columns as this pattern leaves them.
        metadata = pattern.get_updated_metadata(metadata)
341+
326342
def validate(self, data):
327343
"""Validate the data.
328344
@@ -334,8 +350,11 @@ def validate(self, data):
334350
"""
335351
errors = []
336352
constraints_errors = []
337-
cags_errors = []
338-
self.metadata.validate_data(data)
353+
metadata = self.metadata
354+
if hasattr(self, '_original_metadata'):
355+
metadata = self._original_metadata
356+
357+
metadata.validate_data(data)
339358
for table_name in data:
340359
if table_name in self._table_synthesizers:
341360
try:
@@ -346,21 +365,13 @@ def validate(self, data):
346365
# Validate rules specific to each synthesizer
347366
errors += self._table_synthesizers[table_name]._validate(data[table_name])
348367

349-
# Validate single-table cags
350-
if hasattr(self._table_synthesizers[table_name], '_validate_cags'):
351-
try:
352-
self._table_synthesizers[table_name]._validate_cags(data[table_name])
353-
except PatternNotMetError as error:
354-
cags_errors.append(error)
355-
356368
if constraints_errors:
357369
raise ConstraintsNotMetError(constraints_errors)
358370

359371
elif errors:
360372
raise InvalidDataError(errors)
361373

362-
elif cags_errors:
363-
raise PatternNotMetError(cags_errors)
374+
self._validate_cags(data)
364375

365376
def _validate_table_name(self, table_name):
366377
if table_name not in self._table_synthesizers:
@@ -462,8 +473,8 @@ def preprocess(self, data):
462473
"""
463474
list_of_changed_tables = self._store_and_convert_original_cols(data)
464475

465-
data = self._transform_helper(data)
466476
self.validate(data)
477+
data = self._transform_helper(data)
467478
if self._fitted:
468479
warnings.warn(
469480
'This model has already been fitted. To use the new preprocessed data, '

tests/integration/multi_table/test_hma.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from sdmetrics.reports.multi_table import DiagnosticReport
1616

1717
from sdv import version
18+
from sdv.cag import Inequality
19+
from sdv.cag._errors import PatternNotMetError
1820
from sdv.datasets.demo import download_demo
1921
from sdv.datasets.local import load_csvs
2022
from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError, VersionError
@@ -2686,3 +2688,35 @@ def test__unsupported_regex_format():
26862688
# Run and Assert
26872689
with pytest.raises(SynthesizerInputError, match=expected_error):
26882690
HMASynthesizer(metadata)
2691+
2692+
2693+
def test_end_to_end_with_cags():
    """Test fitting, sampling and validating HMA with one single-table CAG pattern."""
    # Setup
    data, metadata = download_demo('multi_table', 'fake_hotels')
    synthesizer = HMASynthesizer(metadata)
    checkin_before_checkout = Inequality(
        low_column_name='checkin_date',
        high_column_name='checkout_date',
        strict_boundaries=False,
        table_name='guests',
    )
    synthesizer.add_cag(patterns=[checkin_before_checkout])

    # Keep only the guests rows where both dates are present.
    guests = data['guests']
    has_missing_date = guests[['checkin_date', 'checkout_date']].isna().any(axis=1)
    guests_valid = guests[~has_missing_date]

    # Overwrite one check-in date so that the copy violates the pattern.
    guests_broken = guests_valid.copy()
    guests_broken.loc[0, 'checkin_date'] = '31 Dec 2020'

    data['guests'] = guests_valid
    invalid_data = data.copy()
    invalid_data['guests'] = guests_broken
    expected_error_msg = re.escape('The inequality requirement is not met for row indices: [0]')

    # Run
    synthesizer.fit(data)
    synthetic_data = synthesizer.sample(scale=1.0)

    with pytest.raises(PatternNotMetError, match=expected_error_msg):
        synthesizer.fit(invalid_data)

    # Assert
    synthesizer.validate(synthetic_data)

tests/unit/multi_table/test_base.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,16 +410,44 @@ def test_get_metadata(self):
410410
assert type(result) is Metadata
411411
assert expected_metadata.to_dict() == result.to_dict()
412412

413+
def test__validate_cags(self):
    """Test that ``_validate_cags`` feeds each pattern the previous pattern's metadata."""
    # Setup
    data = pd.DataFrame()
    original_metadata = Metadata()
    metadata_1 = Metadata()
    metadata_2 = Metadata()
    instance = BaseMultiTableSynthesizer(original_metadata)

    cag_mock_1 = Mock()
    cag_mock_2 = Mock()
    cag_mock_3 = Mock()
    cag_mock_1.get_updated_metadata.return_value = metadata_1
    cag_mock_2.get_updated_metadata.return_value = metadata_2
    instance.patterns = [cag_mock_1, cag_mock_2, cag_mock_3]

    # Run
    instance._validate_cags(data)

    # Assert
    starting_metadata = instance._original_metadata
    cag_mock_1.get_updated_metadata.assert_called_once_with(starting_metadata)
    cag_mock_1.validate.assert_called_once_with(data=data, metadata=starting_metadata)
    cag_mock_2.validate.assert_called_once_with(data=data, metadata=metadata_1)
    cag_mock_3.validate.assert_called_once_with(data=data, metadata=metadata_2)
436+
413437
def test_validate(self):
    """Test that valid data raises no error and is handed to ``_validate_cags``."""
    # Setup
    metadata = get_multi_table_metadata()
    data = get_multi_table_data()
    instance = BaseMultiTableSynthesizer(metadata)
    cag_validation = Mock()
    instance._validate_cags = cag_validation

    # Run
    instance.validate(data)

    # Assert
    cag_validation.assert_called_once_with(data)
450+
423451
def test_validate_missing_table(self):
424452
"""Test that an error is being raised when there is a missing table in the dictionary."""
425453
# Setup

0 commit comments

Comments
 (0)