Skip to content

Commit b546a72

Browse files
add max_gm_samples param and subsample continuous columns before fitting GMs
1 parent 86fcd23 commit b546a72

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

ctgan/data_transformer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,27 @@ class DataTransformer(object):
1919
Discrete columns are encoded using a scikit-learn OneHotEncoder.
2020
"""
2121

22-
def __init__(self, max_clusters=10, weight_threshold=0.005):
22+
def __init__(self, max_clusters=10, weight_threshold=0.005, max_gm_samples=None):
2323
"""Create a data transformer.
2424
2525
Args:
2626
max_clusters (int):
2727
Maximum number of Gaussian distributions in Bayesian GMM.
2828
weight_threshold (float):
2929
Weight threshold for a Gaussian distribution to be kept.
30+
max_gm_samples (int):
31+
Maximum number of samples to use when fitting the GMM. If None, all samples are used.
3032
"""
3133
self._max_clusters = max_clusters
3234
self._weight_threshold = weight_threshold
35+
self._max_gm_samples = np.inf if max_gm_samples is None else max_gm_samples
3336

3437
def _fit_continuous(self, column_name, raw_column_data):
3538
"""Train Bayesian GMM for continuous column."""
39+
if self._max_gm_samples <= raw_column_data.shape[0]:
40+
raw_column_data = np.random.choice(raw_column_data,
41+
size=self._max_gm_samples,
42+
replace=False)
3643
gm = BayesianGaussianMixture(
3744
self._max_clusters,
3845
weight_concentration_prior_type='dirichlet_process',

0 commit comments

Comments
 (0)