Skip to content

Commit 88ee545

Browse files
committed
Update to add to a device
1 parent 4d9360e commit 88ee545

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Remove data
22
data/
3+
run/
34

45
# Byte-compiled / optimized / DLL files
56
__pycache__/

train.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
import torch.optim as optim
23
from data import GANData
34
from discriminator import Discriminator
@@ -14,6 +15,7 @@
1415
BATCH_SIZE = 16
1516
K = 1
1617
latent_size = 32
18+
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1719

1820
# Writer will output to ./runs/ directory by default
1921
writer = SummaryWriter()
@@ -23,8 +25,8 @@
2325
img_size = data.get_img_size()
2426

2527
# Instantiate models
26-
D = Discriminator(img_size)
27-
G = Generator(img_size, latent_size)
28+
D = Discriminator(img_size).to(DEVICE)
29+
G = Generator(img_size, latent_size).to(DEVICE)
2830

2931
# Instantiate criterion for both D and G
3032
D_criterion = DiscriminatorCriterion()
@@ -48,7 +50,8 @@
4850
for k in range(K):
4951
# Sample minibatches from P_data and P_z
5052
data_mb, _ = next(iter(data.trainloader))
51-
z_mb = G.sample_z(BATCH_SIZE)
53+
data_mb = data_mb.to(DEVICE)
54+
z_mb = G.sample_z(BATCH_SIZE).to(DEVICE)
5255

5356
# Clear accumulated gradients
5457
D_optim.zero_grad()
@@ -67,7 +70,7 @@
6770
D_optim.step()
6871

6972
# Update Generator
70-
z_mb = G.sample_z(BATCH_SIZE)
73+
z_mb = G.sample_z(BATCH_SIZE).to(DEVICE)
7174

7275
# Clear accumulated gradients
7376
G_optim.zero_grad()
@@ -97,7 +100,7 @@
97100
writer.add_scalar('discriminator_loss', running_d_loss/d_norm, global_step)
98101
writer.add_scalar('generator_loss', running_g_loss/g_norm, global_step)
99102

100-
z_mb = G.sample_z(8)
103+
z_mb = G.sample_z(8).to(DEVICE)
101104
generated_samples = G(z_mb)
102105

103106
grid = torchvision.utils.make_grid(generated_samples.reshape(8, 1, 28, 28))

0 commit comments

Comments
 (0)