@@ -190,6 +190,7 @@ def __init__(
190
190
self ._transformer = None
191
191
self ._data_sampler = None
192
192
self ._generator = None
193
+ self ._discriminator = None
193
194
self .loss_values = None
194
195
195
196
@staticmethod
@@ -330,7 +331,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
330
331
self ._embedding_dim + self ._data_sampler .dim_cond_vec (), self ._generator_dim , data_dim
331
332
).to (self ._device )
332
333
333
- discriminator = Discriminator (
334
+ self . _discriminator = Discriminator (
334
335
data_dim + self ._data_sampler .dim_cond_vec (), self ._discriminator_dim , pac = self .pac
335
336
).to (self ._device )
336
337
@@ -342,7 +343,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
342
343
)
343
344
344
345
optimizerD = optim .Adam (
345
- discriminator .parameters (),
346
+ self . _discriminator .parameters (),
346
347
lr = self ._discriminator_lr ,
347
348
betas = (0.5 , 0.9 ),
348
349
weight_decay = self ._discriminator_decay ,
@@ -395,10 +396,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
395
396
real_cat = real
396
397
fake_cat = fakeact
397
398
398
- y_fake = discriminator (fake_cat )
399
- y_real = discriminator (real_cat )
399
+ y_fake = self . _discriminator (fake_cat )
400
+ y_real = self . _discriminator (real_cat )
400
401
401
- pen = discriminator .calc_gradient_penalty (
402
+ pen = self . _discriminator .calc_gradient_penalty (
402
403
real_cat , fake_cat , self ._device , self .pac
403
404
)
404
405
loss_d = - (torch .mean (y_real ) - torch .mean (y_fake ))
@@ -423,9 +424,9 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
423
424
fakeact = self ._apply_activate (fake )
424
425
425
426
if c1 is not None :
426
- y_fake = discriminator (torch .cat ([fakeact , c1 ], dim = 1 ))
427
+ y_fake = self . _discriminator (torch .cat ([fakeact , c1 ], dim = 1 ))
427
428
else :
428
- y_fake = discriminator (fakeact )
429
+ y_fake = self . _discriminator (fakeact )
429
430
430
431
if condvec is None :
431
432
cross_entropy = 0
@@ -520,3 +521,5 @@ def set_device(self, device):
520
521
self ._device = device
521
522
if self ._generator is not None :
522
523
self ._generator .to (self ._device )
524
+ if self ._discriminator is not None :
525
+ self ._discriminator .to (self ._device )
0 commit comments