Skip to content

Commit a4c4d5b

Browse files
add test to check max_gm_samples
1 parent b2ced54 commit a4c4d5b

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/integration/test_ctgan.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,14 @@ def test_wrong_sampling_conditions():
184184

185185
with pytest.raises(ValueError):
186186
ctgan.sample(1, 'discrete', "d")
187+
188+
189+
def test_ctgan_data_transformer_params():
190+
data = pd.DataFrame({
191+
'continuous': np.random.random(1000)
192+
})
193+
194+
ctgan = CTGANSynthesizer(epochs=1)
195+
ctgan.fit(data, [], data_transformer_params={'max_gm_samples': 100})
196+
197+
assert ctgan._transformer._max_gm_samples == 100

0 commit comments

Comments
 (0)