Skip to content

Commit fc0b3e3

Browse files
committed
fix (ctgan): add discriminator to model attributes
1 parent c0ea824 commit fc0b3e3

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

ctgan/synthesizers/ctgan.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def __init__(
190190
self._transformer = None
191191
self._data_sampler = None
192192
self._generator = None
193+
self._discriminator = None
193194
self.loss_values = None
194195

195196
@staticmethod
@@ -330,7 +331,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
330331
self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim
331332
).to(self._device)
332333

333-
discriminator = Discriminator(
334+
self._discriminator = Discriminator(
334335
data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac
335336
).to(self._device)
336337

@@ -342,7 +343,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
342343
)
343344

344345
optimizerD = optim.Adam(
345-
discriminator.parameters(),
346+
self._discriminator.parameters(),
346347
lr=self._discriminator_lr,
347348
betas=(0.5, 0.9),
348349
weight_decay=self._discriminator_decay,
@@ -395,10 +396,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
395396
real_cat = real
396397
fake_cat = fakeact
397398

398-
y_fake = discriminator(fake_cat)
399-
y_real = discriminator(real_cat)
399+
y_fake = self._discriminator(fake_cat)
400+
y_real = self._discriminator(real_cat)
400401

401-
pen = discriminator.calc_gradient_penalty(
402+
pen = self._discriminator.calc_gradient_penalty(
402403
real_cat, fake_cat, self._device, self.pac
403404
)
404405
loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
@@ -423,9 +424,9 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
423424
fakeact = self._apply_activate(fake)
424425

425426
if c1 is not None:
426-
y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
427+
y_fake = self._discriminator(torch.cat([fakeact, c1], dim=1))
427428
else:
428-
y_fake = discriminator(fakeact)
429+
y_fake = self._discriminator(fakeact)
429430

430431
if condvec is None:
431432
cross_entropy = 0
@@ -520,3 +521,5 @@ def set_device(self, device):
520521
self._device = device
521522
if self._generator is not None:
522523
self._generator.to(self._device)
524+
if self._discriminator is not None:
525+
self._discriminator.to(self._device)

0 commit comments

Comments
 (0)