7 changes: 6 additions & 1 deletion sdv/metadata/multi_table.py
@@ -15,7 +15,7 @@
 from sdv.logging import get_sdv_logger
 from sdv.metadata.errors import InvalidMetadataError
 from sdv.metadata.metadata_upgrader import convert_metadata
-from sdv.metadata.single_table import SingleTableMetadata
+from sdv.metadata.single_table import INT_REGEX_ZERO_ERROR_MESSAGE, SingleTableMetadata
 from sdv.metadata.utils import _validate_file_mode, read_json, validate_file_does_not_exist
 from sdv.metadata.visualization import (
     create_columns_node,
@@ -844,6 +844,11 @@ def _validate_all_tables(self, data):
                     self.tables[table_name].validate_data(table_data, table_sdtype_warnings)

             except InvalidDataError as error:
+                if INT_REGEX_ZERO_ERROR_MESSAGE in str(error) and len(self.tables) > 1:
+                    raise InvalidDataError([
+                        f'Primary key for table "{table_name}" {INT_REGEX_ZERO_ERROR_MESSAGE}'
+                    ])
+
                 error_msg = f'Errors in {table_name}:'
                 for _error in error.errors:
                     error_msg += f'\nError: {_error}'
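For context, a minimal sketch of the failure mode this new branch reports (the column values and regex are hypothetical, not taken from this PR): an integer column cannot preserve leading zeros, so a regex whose first character may be `0` describes keys the column cannot round-trip.

```python
import pandas as pd

# An int64 primary key drops leading zeros, so a regex_format whose first
# character can be '0' (e.g. '[0-9]{3}' matching '042') describes values
# the column cannot actually store.
ids = pd.Series([7, 42, 103], dtype='int64')
regex_format = '[0-9]{3}'  # matches '042', but int64 collapses it to 42

assert str(ids.iloc[1]) == '42'  # the leading zero of '042' is unrecoverable
```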
17 changes: 17 additions & 0 deletions sdv/metadata/single_table.py
@@ -21,6 +21,7 @@
     _is_numerical_type,
     _load_data_from_csv,
     _validate_datetime_format,
+    get_possible_chars,
 )
 from sdv.errors import InvalidDataError
 from sdv.logging import get_sdv_logger
@@ -35,6 +36,10 @@

 LOGGER = logging.getLogger(__name__)
 SINGLETABLEMETADATA_LOGGER = get_sdv_logger('SingleTableMetadata')
+INT_REGEX_ZERO_ERROR_MESSAGE = (
+    'is stored as an int but the Regex allows it to start with "0". Please remove the Regex '
+    'or update it to correspond to valid ints.'
+)


 class SingleTableMetadata:
@@ -1185,6 +1190,17 @@ def _validate_key_values_are_unique(self, data):

         return errors

+    def _validate_primary_key(self, data):
+        error = []
+        is_int = self.primary_key and pd.api.types.is_integer_dtype(data[self.primary_key])
+        regex = self.columns.get(self.primary_key, {}).get('regex_format')
+        if is_int and regex:
+            possible_characters = get_possible_chars(regex, 1)
+            if '0' in possible_characters:
+                error.append(f'Primary key "{self.primary_key}" {INT_REGEX_ZERO_ERROR_MESSAGE}')
+
+        return error
+
     @staticmethod
     def _get_invalid_column_values(column, validation_function):
         valid = column.apply(validation_function).astype(bool)
@@ -1290,6 +1306,7 @@ def validate_data(self, data, sdtype_warnings=None):
         for column in data:
             errors += self._validate_column_data(data[column], sdtype_warnings)

+        errors += self._validate_primary_key(data)
         if sdtype_warnings is not None and len(sdtype_warnings):
             df = pd.DataFrame(sdtype_warnings)
             message = (
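Taken together, the new validator behaves roughly like the stand-alone sketch below. `get_possible_chars` is stubbed here: the real helper lives in `sdv._utils`, and based on the call site above it is assumed to return the characters the leading regex positions can produce.

```python
import pandas as pd

def get_possible_chars_stub(regex, num_positions):
    # Stand-in for sdv._utils.get_possible_chars: for a simple '[0-9]{k}'
    # pattern, the first position can be any digit, including '0'.
    return set('0123456789')

data = pd.DataFrame({'user_id': pd.Series([1, 2, 3], dtype='int64')})
columns = {'user_id': {'sdtype': 'id', 'regex_format': '[0-9]{3}'}}

errors = []
regex = columns['user_id'].get('regex_format')
if pd.api.types.is_integer_dtype(data['user_id']) and regex:
    if '0' in get_possible_chars_stub(regex, 1):
        errors.append('Primary key "user_id" is stored as an int but the Regex allows it to start with "0". ...')

print(errors)  # one error for this column/regex combination
```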
56 changes: 26 additions & 30 deletions sdv/multi_table/base.py
@@ -29,7 +29,6 @@
 from sdv.logging import disable_single_table_logger, get_sdv_logger
 from sdv.metadata.metadata import Metadata
 from sdv.metadata.multi_table import MultiTableMetadata
-from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE
 from sdv.single_table.copulas import GaussianCopulaSynthesizer

 SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')
@@ -120,6 +119,8 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
         self._table_synthesizers = {}
         self._table_parameters = defaultdict(dict)
         self._original_table_columns = {}
+        self._original_metadata = deepcopy(self.metadata)
+        self.patterns = []
         if synthesizer_kwargs is not None:
             warn_message = (
                 'The `synthesizer_kwargs` parameter is deprecated as of SDV 1.2.0 and does not '
@@ -133,6 +134,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None)

         self._initialize_models()
         self._fitted = False
+        self._constraints_fitted = False
Review comment (Contributor):
This parameter is not used at all for the sampling process, right? I'm wondering whether we need to worry about backwards compatibility for it.

Reply (Contributor, Author):
That's a good point. It's not currently used for sampling, and I checked that all the tests pass on enterprise when sdv points to this branch.

         self._creation_date = datetime.datetime.today().strftime('%Y-%m-%d')
         self._fitted_date = None
         self._fitted_sdv_version = None
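On the backwards-compatibility question above: if the flag ever did matter for loading older models, one conventional guard (purely a sketch, not part of this PR) is to default it when unpickling synthesizers saved before this change.

```python
class BaseMultiTableSynthesizer:  # simplified stand-in for the real class
    def __setstate__(self, state):
        self.__dict__.update(state)
        # Synthesizers pickled before this change have no _constraints_fitted;
        # default it so attribute access stays safe after loading.
        self.__dict__.setdefault('_constraints_fitted', False)
```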
@@ -166,19 +168,12 @@ def add_cag(self, patterns):
             patterns (list):
                 A list of CAG patterns to apply to the synthesizer.
         """
-        if not hasattr(self, '_original_metadata'):
-            self._original_metadata = self.metadata
-
         metadata = self.metadata
         for pattern in patterns:
             metadata = pattern.get_updated_metadata(metadata)

         self.metadata = metadata
-        if hasattr(self, 'patterns'):
-            self.patterns += patterns
-        else:
-            self.patterns = patterns
-
+        self.patterns += patterns
         self._initialize_models()

     def get_cag(self):
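Because `self.patterns` is now always initialized in `__init__`, repeated calls simply extend the list. A hedged usage sketch (`synth` and the pattern objects are hypothetical):

```python
# synth is an already-constructed multi-table synthesizer
synth.add_cag(patterns=[pattern_a])
synth.add_cag(patterns=[pattern_b])
# synth.patterns == [pattern_a, pattern_b], and synth.metadata reflects
# both patterns' get_updated_metadata updates, applied in order
```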
@@ -233,27 +228,33 @@ def get_metadata(self, version='original'):

         return Metadata.load_from_dict(self.metadata.to_dict())

-    def _transform_helper(self, data):
+    def _validate_transform_constraints(self, data, enforce_constraint_fitting=False):
         """Validate and transform all CAG patterns during preprocessing.

         Args:
             data (dict[str, pd.DataFrame]):
                 The data dictionary.
+            enforce_constraint_fitting (bool):
+                Whether to enforce fitting the constraints again. If set to ``True``, the
+                constraints will be fitted again even if they have already been fitted.
+                Defaults to ``False``.
         """
-        if not hasattr(self, 'patterns'):
+        if self._constraints_fitted and not enforce_constraint_fitting:
+            for pattern in self.patterns:
+                data = pattern.transform(data)
+
             return data

         metadata = self._original_metadata
         for pattern in self.patterns:
-            if not self._fitted:
-                pattern.fit(data, metadata)
-                metadata = pattern.get_updated_metadata(metadata)
-
+            pattern.fit(data=data, metadata=metadata)
+            metadata = pattern.get_updated_metadata(metadata)
             data = pattern.transform(data)

+        self._constraints_fitted = True
         return data

-    def _reverse_transform_helper(self, sampled_data):
+    def _reverse_transform_constraints(self, sampled_data):
         """Reverse transform CAG patterns after sampling."""
         if not hasattr(self, 'patterns'):
             return sampled_data
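The flag gives the method two paths: the first call fits each pattern against the cumulatively updated metadata and marks the synthesizer, while later calls (unless `enforce_constraint_fitting=True`) only transform. A minimal sketch with a hypothetical pattern that counts its fits:

```python
class CountingPattern:
    """Hypothetical CAG pattern that records how often it is fitted."""

    fit_calls = 0

    def fit(self, data, metadata):
        CountingPattern.fit_calls += 1

    def get_updated_metadata(self, metadata):
        return metadata

    def transform(self, data):
        return data

# First _validate_transform_constraints call: fit + transform, fit_calls == 1.
# Later calls with the flag already set: transform only, fit_calls stays at 1.
```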
@@ -359,7 +360,8 @@ def validate(self, data):
         """
         errors = []
         constraints_errors = []
-        self.metadata.validate_data(data)
+        metadata = self._original_metadata
+        metadata.validate_data(data)
         for table_name in data:
             if table_name in self._table_synthesizers:
                 try:
@@ -376,6 +378,8 @@ def validate(self, data):
         elif errors:
             raise InvalidDataError(errors)

+        self._validate_transform_constraints(data, enforce_constraint_fitting=True)
+
     def _validate_table_name(self, table_name):
         if table_name not in self._table_synthesizers:
             raise ValueError(
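A hedged usage sketch (the synthesizer instance and tables are hypothetical): `validate` now checks the raw data against the pre-CAG metadata and then runs a full constraint fit-and-transform pass, so pattern violations surface before `fit` is ever called.

```python
import pandas as pd

data = {
    'users': pd.DataFrame({'user_id': [1, 2, 3], 'age': [25, 31, 47]}),
    'sessions': pd.DataFrame({'session_id': [10, 11], 'user_id': [1, 3]}),
}

# synthesizer is an already-constructed multi-table synthesizer
synthesizer.validate(data)  # raises InvalidDataError on metadata violations,
                            # then refits and applies every CAG pattern
```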
@@ -476,8 +480,8 @@ def preprocess(self, data):
         """
         list_of_changed_tables = self._store_and_convert_original_cols(data)

-        data = self._transform_helper(data)
         self.validate(data)
+        data = self._validate_transform_constraints(data)
         if self._fitted:
             warnings.warn(
                 'This model has already been fitted. To use the new preprocessed data, '
@@ -487,17 +491,9 @@ def preprocess(self, data):
         processed_data = {}
         pbar_args = self._get_pbar_args(desc='Preprocess Tables')
         for table_name, table_data in tqdm(data.items(), **pbar_args):
-            try:
-                synthesizer = self._table_synthesizers[table_name]
-                self._assign_table_transformers(synthesizer, table_name, table_data)
-                processed_data[table_name] = synthesizer._preprocess(table_data)
-            except SynthesizerInputError as e:
-                if INT_REGEX_ZERO_ERROR_MESSAGE in str(e):
-                    raise SynthesizerInputError(
-                        f'Primary key for table "{table_name}" {INT_REGEX_ZERO_ERROR_MESSAGE}'
-                    )
-
-                raise e
+            synthesizer = self._table_synthesizers[table_name]
+            self._assign_table_transformers(synthesizer, table_name, table_data)
+            processed_data[table_name] = synthesizer._preprocess(table_data)

         for table in list_of_changed_tables:
             data[table] = data[table].rename(columns=self._original_table_columns[table])
@@ -619,7 +615,7 @@ def sample(self, scale=1.0):

         with self._set_temp_numpy_seed(), disable_single_table_logger():
             sampled_data = self._sample(scale=scale)
-            sampled_data = self._reverse_transform_helper(sampled_data)
+            sampled_data = self._reverse_transform_constraints(sampled_data)

         total_rows = 0
         total_columns = 0