Skip to content

Commit 61cb10e

Browse files
committed
Move the _store_and_convert_original_cols call out of _preprocess_helper and into preprocess
1 parent 739d31b commit 61cb10e

File tree

2 files changed

+12
-17
lines changed

2 files changed

+12
-17
lines changed

sdv/single_table/base.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -547,15 +547,14 @@ def _preprocess_helper(self, data):
547547
- Warn the user if the model has already been fitted
548548
- Store the original columns and convert them to string if needed
549549
"""
550-
is_converted = self._store_and_convert_original_cols(data)
551550
self.validate(data)
552551
if self._fitted:
553552
warnings.warn(
554553
'This model has already been fitted. To use the new preprocessed data, '
555554
"please refit the model using 'fit' or 'fit_processed_data'."
556555
)
557556

558-
return data, is_converted
557+
return data
559558

560559
def preprocess(self, data):
561560
"""Transform the raw data to numerical space.
@@ -568,7 +567,8 @@ def preprocess(self, data):
568567
pandas.DataFrame:
569568
The preprocessed data.
570569
"""
571-
data, is_converted = self._preprocess_helper(data)
570+
is_converted = self._store_and_convert_original_cols(data)
571+
data = self._preprocess_helper(data)
572572
preprocess_data = self._preprocess(data)
573573
if is_converted:
574574
data.columns = self._original_columns
@@ -856,10 +856,10 @@ def validate(self, data):
856856
self.metadata = metadata
857857

858858
def _preprocess_helper(self, data):
859-
data, is_converted = super()._preprocess_helper(data)
859+
data = super()._preprocess_helper(data)
860860
data = self._transform_helper(data)
861861

862-
return data, is_converted
862+
return data
863863

864864
def _set_random_state(self, random_state):
865865
"""Set the random state of the model's random number generator.

tests/unit/single_table/test_base.py

Lines changed: 7 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -505,7 +505,6 @@ def test__preprocess_helper_basesynthesizer(self, mock_warnings):
505505
instance = Mock()
506506
instance._fitted = True
507507
data = pd.DataFrame({'name': ['John', 'Doe', 'John Doe']})
508-
instance._store_and_convert_original_cols.return_value = False
509508

510509
# Run
511510
result = BaseSynthesizer._preprocess_helper(instance, data)
@@ -517,8 +516,7 @@ def test__preprocess_helper_basesynthesizer(self, mock_warnings):
517516
)
518517
mock_warnings.warn.assert_called_once_with(expected_warning)
519518
instance.validate.assert_called_once_with(data)
520-
pd.testing.assert_frame_equal(result[0], data)
521-
assert result[1] is False
519+
pd.testing.assert_frame_equal(result, data)
522520

523521
@patch.object(BaseSynthesizer, '_preprocess_helper')
524522
def test__preprocess_helper(self, mock_preprocess_helper):
@@ -536,14 +534,13 @@ def test__preprocess_helper(self, mock_preprocess_helper):
536534
instance = BaseSingleTableSynthesizer(metadata)
537535
data = pd.DataFrame({'name': ['John', 'Doe', 'John Doe']})
538536
instance._transform_helper = Mock(return_value=data)
539-
mock_preprocess_helper.return_value = (data, False)
537+
mock_preprocess_helper.return_value = data
540538

541539
# Run
542540
result = instance._preprocess_helper(data)
543541

544542
# Assert
545-
pd.testing.assert_frame_equal(result[0], data)
546-
assert result[1] is False
543+
pd.testing.assert_frame_equal(result, data)
547544
instance._transform_helper.assert_called_once_with(data)
548545

549546
def test__preprocess(self):
@@ -569,22 +566,20 @@ def test__preprocess(self):
569566
)
570567

571568
def test_preprocess(self):
572-
"""Test the preprocess method.
573-
574-
The preprocess method raises a warning if it was already fitted and then calls
575-
``_preprocess``.
576-
"""
569+
"""Test the preprocess method."""
577570
# Setup
578571
instance = Mock()
579572
instance._fitted = True
580573
data = pd.DataFrame({'name': ['John', 'Doe', 'John Doe']})
581-
instance._preprocess_helper.return_value = (data, False)
574+
instance._preprocess_helper.return_value = data
575+
instance._store_and_convert_original_cols = Mock(return_value=False)
582576
instance._preprocess.return_value = data
583577

584578
# Run
585579
result = BaseSingleTableSynthesizer.preprocess(instance, data)
586580

587581
# Assert
582+
instance._store_and_convert_original_cols.assert_called_once_with(data)
588583
instance._preprocess_helper.assert_called_once_with(data)
589584
instance._preprocess.assert_called_once_with(data)
590585
pd.testing.assert_frame_equal(result, data)

0 commit comments

Comments
 (0)