Commit 7bca5fa

Add integration tests for DataTransformer, remove unnecessary code and fix max_clusters bug (#314)
* Remove unnecessary tests
* Lots of minor changes
* Fix lint
* Fix lint
* Fix lint...
* Add weight parameter
* Fix test
* Fix typos

1 parent 2848a42 · commit 7bca5fa

File tree

7 files changed: +154 additions, -312 deletions


ctgan/data_transformer.py

Lines changed: 6 additions & 3 deletions

@@ -18,8 +18,8 @@
 class DataTransformer(object):
     """Data Transformer.
 
-    Model continuous columns with a BayesianGMM and normalized to a scalar [0, 1] and a vector.
-    Discrete columns are encoded using a scikit-learn OneHotEncoder.
+    Model continuous columns with a BayesianGMM and normalize them to a scalar between [-1, 1]
+    and a vector. Discrete columns are encoded using a OneHotEncoder.
     """
 
     def __init__(self, max_clusters=10, weight_threshold=0.005):

@@ -47,7 +47,10 @@ def _fit_continuous(self, data):
         """
         column_name = data.columns[0]
         gm = ClusterBasedNormalizer(
-            missing_value_generation='from_column', max_clusters=min(len(data), 10))
+            missing_value_generation='from_column',
+            max_clusters=min(len(data), self._max_clusters),
+            weight_threshold=self._weight_threshold
+        )
         gm.fit(data, column_name)
         num_components = sum(gm.valid_component_indicator)
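For context, the practical effect of this change is that the `max_clusters` and `weight_threshold` arguments accepted by `DataTransformer.__init__` are now forwarded to the `ClusterBasedNormalizer`, rather than the cluster count being capped at a hard-coded 10. A minimal sketch of how a caller can rely on the fix (the column name and sizes here are illustrative, not part of this commit):

    import numpy as np
    import pandas as pd

    from ctgan.data_transformer import DataTransformer

    # Illustrative continuous column.
    data = pd.DataFrame({'num': np.random.normal(size=500)})

    # With this fix, max_clusters=3 limits the BayesianGMM to 3 components,
    # so the output is 1 scalar column plus at most 3 one-hot columns.
    transformer = DataTransformer(max_clusters=3, weight_threshold=0.01)
    transformer.fit(data, [])
    transformed = transformer.transform(data)
    assert transformed.shape[1] <= 4  # before the fix, the cap was always min(len(data), 10)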

tests/integration/synthesizer/test_ctgan.py

Lines changed: 0 additions & 50 deletions

@@ -231,56 +231,6 @@ def test_fixed_random_seed():
     np.testing.assert_array_equal(sampled_0_1, sampled_1_1)
 
 
-# Below are CTGAN tests that should be implemented in the future
-def test_continuous():
-    """Test training the CTGAN synthesizer on a continuous dataset."""
-    # assert the distribution of the samples is close to the distribution of the data
-    # using kstest:
-    # - uniform (assert p-value > 0.05)
-    # - gaussian (assert p-value > 0.05)
-    # - inversely correlated (assert correlation < 0)
-    pass
-
-
-def test_categorical():
-    """Test training the CTGAN synthesizer on a categorical dataset."""
-    # assert the distribution of the samples is close to the distribution of the data
-    # using cstest:
-    # - uniform (assert p-value > 0.05)
-    # - very skewed / biased? (assert p-value > 0.05)
-    # - inversely correlated (assert correlation < 0)
-    pass
-
-
-def test_categorical_log_frequency():
-    """Test training the CTGAN synthesizer on a small categorical dataset."""
-    # assert the distribution of the samples is close to the distribution of the data
-    # using cstest:
-    # - uniform (assert p-value > 0.05)
-    # - very skewed / biased? (assert p-value > 0.05)
-    # - inversely correlated (assert correlation < 0)
-    pass
-
-
-def test_mixed():
-    """Test training the CTGAN synthesizer on a small mixed-type dataset."""
-    # assert the distribution of the samples is close to the distribution of the data
-    # using a kstest for continuous + a cstest for categorical.
-    pass
-
-
-def test_conditional():
-    """Test training the CTGAN synthesizer and sampling conditioned on a categorical."""
-    # verify that conditioning increases the likelihood of getting a sample with the specified
-    # categorical value
-    pass
-
-
-def test_batch_size_pack_size():
-    """Test that if batch size is not a multiple of pack size, it raises a sane error."""
-    pass
-
-
 def test_ctgan_save_and_load(tmpdir):
     """Test that the ``CTGAN`` model can be saved and loaded."""
     # Setup
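The placeholders removed above sketched a distribution-comparison approach: compare real and synthetic continuous columns with a KS test and require a p-value above 0.05. For reference, a hedged sketch of that idea using scipy's two-sample KS test; the CTGAN fit/sample calls follow the public API, while the column name, sizes, and epoch count are illustrative:

    import numpy as np
    import pandas as pd
    from scipy.stats import ks_2samp

    from ctgan import CTGAN

    # Illustrative continuous training data; a real test would need enough
    # epochs for the synthetic distribution to match the real one.
    data = pd.DataFrame({'num': np.random.normal(size=1000)})

    ctgan = CTGAN(epochs=1)
    ctgan.fit(data, discrete_columns=[])
    samples = ctgan.sample(1000)

    # Two-sample KS test between real and synthetic values; the removed
    # placeholder proposed asserting p_value > 0.05 once the model is trained.
    statistic, p_value = ks_2samp(data['num'], samples['num'])
    print(f'KS statistic={statistic:.3f}, p-value={p_value:.3f}')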

tests/integration/synthesizer/test_tvae.py

Lines changed: 0 additions & 22 deletions

@@ -56,28 +56,6 @@ def test_drop_last_false():
     assert correct >= 95
 
 
-# TVAE tests that should be implemented in the future.
-def test_continuous():
-    """Test training the TVAE synthesizer on a small continuous dataset."""
-    # verify that the distribution of the samples is close to the distribution of the data
-    # using a kstest.
-    pass
-
-
-def test_categorical():
-    """Test training the TVAE synthesizer on a small categorical dataset."""
-    # verify that the distribution of the samples is close to the distribution of the data
-    # using a cstest.
-    pass
-
-
-def test_mixed():
-    """Test training the TVAE synthesizer on a small mixed-type dataset."""
-    # verify that the distribution of the samples is close to the distribution of the data
-    # using a kstest for continuous + a cstest for categorical.
-    pass
-
-
 def test__loss_function():
     """Test the TVAE produces average values similar to the training data."""
     data = pd.DataFrame({
Lines changed: 144 additions & 40 deletions

@@ -1,42 +1,146 @@
 """Data transformer intergration testing module."""
 
-
-# Data Transformer tests that should be implemented in the future.
-def test_constant():
-    """Test transforming a dataframe containing constant values."""
-
-
-def test_df_continuous():
-    """Test transforming a dataframe containing only continuous values."""
-    # validate output ranges [0, 1]
-    # validate output shape (# samples, # output dims)
-    # validate that forward transform is **not** deterministic
-    # make sure it can be inverted
-
-
-def test_df_categorical():
-    """Test transforming a dataframe containing only categorical values."""
-    # validate output ranges [0, 1]
-    # validate output shape (# samples, # output dims)
-    # validate that forward transform is deterministic
-    # make sure it can be inverted
-
-
-def test_df_mixed():
-    """Test transforming a dataframe containing mixed data types."""
-
-
-def test_df_mixed_nan():
-    """Test transforming a dataframe containing mixed data types + NaN for categoricals."""
-
-
-def test_np_continuous():
-    """Test transforming a np.array containing only continuous values."""
-
-
-def test_np_categorical():
-    """Test transforming a np.array containing only categorical values."""
-
-
-def test_np_mixed():
-    """Test transforming a np.array containing mixed data types."""
+from unittest import TestCase
+
+import numpy as np
+import pandas as pd
+
+from ctgan.data_transformer import DataTransformer
+
+
+class TestDataTransformer(TestCase):
+
+    def test_constant(self):
+        """Test transforming a dataframe containing constant values."""
+        # Setup
+        data = pd.DataFrame({'cnt': [123] * 1000})
+        transformer = DataTransformer()
+
+        # Run
+        transformer.fit(data, [])
+        new_data = transformer.transform(data)
+        transformer.inverse_transform(new_data)
+
+        # Assert transformed values are between -1 and 1
+        assert (new_data[:, 0] > -np.ones(len(new_data))).all()
+        assert (new_data[:, 0] < np.ones(len(new_data))).all()
+
+        # Assert transformed values are a gaussian centered in 0 and with std ~ 0
+        assert -.1 < np.mean(new_data[:, 0]) < .1
+        assert 0 <= np.std(new_data[:, 0]) < .1
+
+        # Assert there are at most `max_columns=10` one hot columns
+        assert new_data.shape[0] == 1000
+        assert new_data.shape[1] <= 11
+        assert np.isin(new_data[:, 1:], [0, 1]).all()
+
+    def test_df_continuous(self):
+        """Test transforming a dataframe containing only continuous values."""
+        # Setup
+        data = pd.DataFrame({'col': np.random.normal(size=1000)})
+        transformer = DataTransformer()
+
+        # Run
+        transformer.fit(data, [])
+        new_data = transformer.transform(data)
+        transformer.inverse_transform(new_data)
+
+        # Assert transformed values are between -1 and 1
+        assert (new_data[:, 0] > -np.ones(len(new_data))).all()
+        assert (new_data[:, 0] < np.ones(len(new_data))).all()
+
+        # Assert transformed values are a gaussian centered in 0 and with std = 1/4
+        assert -.1 < np.mean(new_data[:, 0]) < .1
+        assert .2 < np.std(new_data[:, 0]) < .3
+
+        # Assert there are at most `max_columns=10` one hot columns
+        assert new_data.shape[0] == 1000
+        assert new_data.shape[1] <= 11
+        assert np.isin(new_data[:, 1:], [0, 1]).all()
+
+    def test_df_categorical_constant(self):
+        """Test transforming a dataframe containing only constant categorical values."""
+        # Setup
+        data = pd.DataFrame({'cnt': [123] * 1000})
+        transformer = DataTransformer()
+
+        # Run
+        transformer.fit(data, ['cnt'])
+        new_data = transformer.transform(data)
+        transformer.inverse_transform(new_data)
+
+        # Assert there is only 1 one hot vector
+        assert np.array_equal(new_data, np.ones((len(data), 1)))
+
+    def test_df_categorical(self):
+        """Test transforming a dataframe containing only categorical values."""
+        # Setup
+        data = pd.DataFrame({'cat': np.random.choice(['a', 'b', 'c'], size=1000)})
+        transformer = DataTransformer()
+
+        # Run
+        transformer.fit(data, ['cat'])
+        new_data = transformer.transform(data)
+        transformer.inverse_transform(new_data)
+
+        # Assert there are 3 one hot vectors
+        assert new_data.shape[0] == 1000
+        assert new_data.shape[1] == 3
+        assert np.isin(new_data[:, 1:], [0, 1]).all()
+
+    def test_df_mixed(self):
+        """Test transforming a dataframe containing mixed data types."""
+        # Setup
+        data = pd.DataFrame({
+            'num': np.random.normal(size=1000),
+            'cat': np.random.choice(['a', 'b', 'c'], size=1000)
+        })
+        transformer = DataTransformer()
+
+        # Run
+        transformer.fit(data, ['cat'])
+        new_data = transformer.transform(data)
+        transformer.inverse_transform(new_data)
+
+        # Assert transformed numerical values are between -1 and 1
+        assert (new_data[:, 0] > -np.ones(len(new_data))).all()
+        assert (new_data[:, 0] < np.ones(len(new_data))).all()
+
+        # Assert transformed numerical values are a gaussian centered in 0 and with std = 1/4
+        assert -.1 < np.mean(new_data[:, 0]) < .1
+        assert .2 < np.std(new_data[:, 0]) < .3
+
+        # Assert there are at most `max_columns=10` one hot columns for the numerical values
+        # and 3 for the categorical ones
+        assert new_data.shape[0] == 1000
+        assert 5 <= new_data.shape[1] <= 17
+        assert np.isin(new_data[:, 1:], [0, 1]).all()
+
+    def test_numpy(self):
+        """Test transforming a numpy array."""
+        # Setup
+        data = pd.DataFrame({
+            'num': np.random.normal(size=1000),
+            'cat': np.random.choice(['a', 'b', 'c'], size=1000)
+        })
+        data = np.array(data)
+        transformer = DataTransformer()
+
+        # Run
+        transformer.fit(data, [1])
+        new_data = transformer.transform(data)
+        transformer.inverse_transform(new_data)
+
+        # Assert transformed numerical values are between -1 and 1
+        assert (new_data[:, 0] > -np.ones(len(new_data))).all()
+        assert (new_data[:, 0] < np.ones(len(new_data))).all()
+
+        # Assert transformed numerical values are a gaussian centered in 0 and with std = 1/4
+        assert -.1 < np.mean(new_data[:, 0]) < .1
+        assert .2 < np.std(new_data[:, 0]) < .3
+
+        # Assert there are at most `max_columns=10` one hot columns for the numerical values
+        # and 3 for the categorical ones
+        assert new_data.shape[0] == 1000
+        assert 5 <= new_data.shape[1] <= 17
+        assert np.isin(new_data[:, 1:], [0, 1]).all()
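One behaviour the new tests exercise only implicitly is the `inverse_transform` round trip: it is called in every test, but its output is never checked. A small sketch of how that round trip could be asserted for a categorical column; the exact-recovery expectation is an assumption based on the one-hot encoding being deterministic, not something this commit verifies:

    import numpy as np
    import pandas as pd

    from ctgan.data_transformer import DataTransformer

    data = pd.DataFrame({'cat': np.random.choice(['a', 'b', 'c'], size=200)})
    transformer = DataTransformer()
    transformer.fit(data, ['cat'])

    # Assumption: one-hot encoding is deterministic, so inverting the
    # transform should recover the original categories exactly.
    recovered = transformer.inverse_transform(transformer.transform(data))
    assert (recovered['cat'].to_numpy() == data['cat'].to_numpy()).all()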

tests/unit/synthesizer/test_ctgan.py

Lines changed: 0 additions & 46 deletions

@@ -295,49 +295,3 @@ def test__validate_discrete_columns(self):
         ctgan = CTGAN(epochs=1)
         with pytest.raises(ValueError, match=r'Invalid columns found: {\'doesnt exist\'}'):
             ctgan.fit(data, discrete_columns)
-
-    def test_sample(self):
-        """Test `sample` correctly sets `condition_info` and `global_condition_vec`.
-
-        Tests the first 7 lines of sample by mocking the DataTransformer and DataSampler
-        and checking that they are being correctly used.
-
-        Setup:
-            - Create and fit the synthesizer
-            - Mock DataTransformer, DataSampler
-
-        Input:
-            - n = integer
-            - condition_column = string (not None)
-            - condition_value = string (not None)
-
-        Output:
-            Not relevant
-
-        Note:
-            - I'm not sure we need this test
-        """
-
-    def test_set_device(self):
-        """Test 'set_device' if a GPU is available.
-
-        Check that decoder/encoder can successfully be moved to the device.
-        If the machine doesn't have a GPU, this test shouldn't run.
-
-        Setup:
-            - Move decoder/encoder to device
-
-        Input:
-            - device = string
-
-        Output:
-            None
-
-        Side Effects:
-            - Set `self._device` to `device`
-            - Moves `self.decoder` to `self._device`
-
-        Note:
-            - Need to be careful when checking whether the encoder is actually set
-            to the right device, since it's not saved (it's only used in fit).
-        """
