@@ -1,3 +1,4 @@
+import torch
 import torch.optim as optim
 from data import GANData
 from discriminator import Discriminator
@@ -14,6 +15,7 @@
 BATCH_SIZE = 16
 K = 1
 latent_size = 32
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # Writer will output to ./runs/ directory by default
 writer = SummaryWriter()
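
The new `DEVICE` constant is the standard PyTorch pattern for device-agnostic code: pick the first CUDA GPU when one is available and fall back to the CPU otherwise. A minimal self-contained sketch of the same pattern (the tensor `x` is illustrative, not from this repo):

```python
import torch

# Prefer the first CUDA GPU when available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Tensors can then be created directly on that device.
x = torch.randn(16, 32, device=DEVICE)
print(x.device)  # "cuda:0" on a GPU machine, "cpu" otherwise
```
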
@@ -23,8 +25,8 @@
 img_size = data.get_img_size()
 
 # Instantiate models
-D = Discriminator(img_size)
-G = Generator(img_size, latent_size)
+D = Discriminator(img_size).to(DEVICE)
+G = Generator(img_size, latent_size).to(DEVICE)
 
 # Instantiate criterion for both D and G
 D_criterion = DiscriminatorCriterion()
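
A subtlety behind the two `.to(DEVICE)` calls above: on an `nn.Module`, `.to()` moves the parameters in place and returns the module itself, so the chained form is safe; on a plain tensor, `.to()` returns a new tensor and leaves the original where it was. A short illustration (using `nn.Linear` as a stand-in for the repo's models):

```python
import torch
import torch.nn as nn

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# nn.Module.to() is in place: the module's parameters are moved and the
# same module object is returned, so chaining at construction is safe.
net = nn.Linear(784, 1).to(DEVICE)

# Tensor.to() is out of place: it returns a new tensor.
t = torch.zeros(4)
t.to(DEVICE)      # result discarded; t is still on the CPU
t = t.to(DEVICE)  # correct: rebind the name to the moved tensor
```
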
@@ -48,7 +50,8 @@
     for k in range(K):
         # Sample minibatches from P_data and P_z
         data_mb, _ = next(iter(data.trainloader))
-        z_mb = G.sample_z(BATCH_SIZE)
+        data_mb = data_mb.to(DEVICE)
+        z_mb = G.sample_z(BATCH_SIZE).to(DEVICE)
 
         # Clear accumulated gradients
         D_optim.zero_grad()
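
Moving each minibatch with `.to(DEVICE)` right after it is drawn is the standard per-step transfer. If that copy ever becomes a bottleneck, one common refinement (an optional tweak, not part of this commit) is to pin host memory in the `DataLoader` so the transfer can run asynchronously:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Toy stand-in for GANData's trainloader.
dataset = TensorDataset(torch.randn(256, 1, 28, 28), torch.zeros(256))

# pin_memory=True stages batches in page-locked host RAM, which lets
# non_blocking=True overlap the host-to-device copy with GPU compute.
trainloader = DataLoader(dataset, batch_size=16, shuffle=True, pin_memory=True)

for data_mb, _ in trainloader:
    data_mb = data_mb.to(DEVICE, non_blocking=True)
```
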
@@ -67,7 +70,7 @@
         D_optim.step()
 
     # Update Generator
-    z_mb = G.sample_z(BATCH_SIZE)
+    z_mb = G.sample_z(BATCH_SIZE).to(DEVICE)
 
     # Clear accumulated gradients
     G_optim.zero_grad()
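
Both generator updates sample on the CPU and then copy to the device. If `sample_z` is ultimately a `torch.randn` call (an assumption; its body is not shown in this diff), the copy could be avoided by drawing the noise directly on the target device, along these lines:

```python
import torch

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
latent_size = 32

# Hypothetical replacement for G.sample_z: draw the latent batch directly
# on DEVICE so no separate .to(DEVICE) copy is needed afterwards.
def sample_z(batch_size: int) -> torch.Tensor:
    return torch.randn(batch_size, latent_size, device=DEVICE)

z_mb = sample_z(16)  # already on DEVICE
```
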
@@ -97,7 +100,7 @@
     writer.add_scalar('discriminator_loss', running_d_loss/d_norm, global_step)
     writer.add_scalar('generator_loss', running_g_loss/g_norm, global_step)
 
-    z_mb = G.sample_z(8)
+    z_mb = G.sample_z(8).to(DEVICE)
     generated_samples = G(z_mb)
 
     grid = torchvision.utils.make_grid(generated_samples.reshape(8, 1, 28, 28))
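
For completeness, the grid built by `make_grid` is typically pushed to TensorBoard with `add_image`; a minimal sketch of that final step follows (the tag `'generated_samples'` is an assumed name, not taken from the diff). `make_grid` accepts CUDA tensors, and the writer copies the data back to host memory when serializing, so no explicit `.cpu()` call is needed.

```python
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
global_step = 0

# Stand-in for G(z_mb): eight flattened 28x28 samples.
generated_samples = torch.rand(8, 784)

# Tile the samples into one image and log it under an assumed tag name.
grid = torchvision.utils.make_grid(generated_samples.reshape(8, 1, 28, 28))
writer.add_image('generated_samples', grid, global_step)
```
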