Skip to content

Commit 4560e78

Browse files
committed
add data_transformer_params to have control over data_transformer
1 parent 4b0d505 commit 4560e78

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

ctgan/synthesizers/ctgan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,8 @@ def _validate_discrete_columns(self, train_data, discrete_columns):
267267
if invalid_columns:
268268
raise ValueError('Invalid columns found: {}'.format(invalid_columns))
269269

270-
def fit(self, train_data, discrete_columns=tuple(), epochs=None):
270+
def fit(self, train_data, discrete_columns=tuple(), epochs=None,
271+
data_transformer_params={}):
271272
"""Fit the CTGAN Synthesizer models to the training data.
272273
273274
Args:
@@ -278,6 +279,8 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
278279
Vector. If ``train_data`` is a Numpy array, this list should
279280
contain the integer indices of the columns. Otherwise, if it is
280281
a ``pandas.DataFrame``, this list should contain the column names.
282+
data_transformer_params (dict):
283+
Dictionary of parameters for ``DataTransformer`` initialization.
281284
"""
282285
self._validate_discrete_columns(train_data, discrete_columns)
283286

@@ -290,7 +293,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
290293
DeprecationWarning
291294
)
292295

293-
self._transformer = DataTransformer()
296+
self._transformer = DataTransformer(**data_transformer_params)
294297
self._transformer.fit(train_data, discrete_columns)
295298

296299
train_data = self._transformer.transform(train_data)

0 commit comments

Comments
 (0)