@@ -267,7 +267,8 @@ def _validate_discrete_columns(self, train_data, discrete_columns):
267
267
if invalid_columns :
268
268
raise ValueError ('Invalid columns found: {}' .format (invalid_columns ))
269
269
270
- def fit (self , train_data , discrete_columns = tuple (), epochs = None ):
270
+ def fit (self , train_data , discrete_columns = tuple (), epochs = None ,
271
+ data_transformer_params = {}):
271
272
"""Fit the CTGAN Synthesizer models to the training data.
272
273
273
274
Args:
@@ -278,6 +279,8 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
278
279
Vector. If ``train_data`` is a Numpy array, this list should
279
280
contain the integer indices of the columns. Otherwise, if it is
280
281
a ``pandas.DataFrame``, this list should contain the column names.
282
+ data_transformer_params (dict):
283
+ Dictionary of parameters for ``DataTransformer`` initialization.
281
284
"""
282
285
self ._validate_discrete_columns (train_data , discrete_columns )
283
286
@@ -290,7 +293,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
290
293
DeprecationWarning
291
294
)
292
295
293
- self ._transformer = DataTransformer ()
296
+ self ._transformer = DataTransformer (** data_transformer_params )
294
297
self ._transformer .fit (train_data , discrete_columns )
295
298
296
299
train_data = self ._transformer .transform (train_data )
0 commit comments