Skip to content

Commit c5d6be1

Browse files
MNIST normalization. Black refactoring.
1 parent 8e4b612 commit c5d6be1

File tree

20 files changed

+801
-667
lines changed

20 files changed

+801
-667
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -5,6 +5,7 @@
55
.DS_Store
66

77
data/*/
8+
implementations/*/data
89
implementations/*/images
910
implementations/*/saved_models
1011

implementations/aae/aae.py

Lines changed: 45 additions & 30 deletions
Original file line number · Diff line number · Diff line change
@@ -15,32 +15,34 @@
1515
import torch.nn.functional as F
1616
import torch
1717

18-
os.makedirs('images', exist_ok=True)
18+
os.makedirs("images", exist_ok=True)
1919

2020
parser = argparse.ArgumentParser()
21-
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
22-
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
23-
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
24-
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
25-
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
26-
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
27-
parser.add_argument('--latent_dim', type=int, default=10, help='dimensionality of the latent code')
28-
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
29-
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
30-
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image sampling')
21+
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
22+
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
23+
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
24+
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
25+
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
26+
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
27+
parser.add_argument("--latent_dim", type=int, default=10, help="dimensionality of the latent code")
28+
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
29+
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
30+
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
3131
opt = parser.parse_args()
3232
print(opt)
3333

3434
img_shape = (opt.channels, opt.img_size, opt.img_size)
3535

3636
cuda = True if torch.cuda.is_available() else False
3737

38+
3839
def reparameterization(mu, logvar):
3940
std = torch.exp(logvar / 2)
4041
sampled_z = Variable(Tensor(np.random.normal(0, 1, (mu.size(0), opt.latent_dim))))
4142
z = sampled_z * std + mu
4243
return z
4344

45+
4446
class Encoder(nn.Module):
4547
def __init__(self):
4648
super(Encoder, self).__init__()
@@ -50,7 +52,7 @@ def __init__(self):
5052
nn.LeakyReLU(0.2, inplace=True),
5153
nn.Linear(512, 512),
5254
nn.BatchNorm1d(512),
53-
nn.LeakyReLU(0.2, inplace=True)
55+
nn.LeakyReLU(0.2, inplace=True),
5456
)
5557

5658
self.mu = nn.Linear(512, opt.latent_dim)
@@ -64,6 +66,7 @@ def forward(self, img):
6466
z = reparameterization(mu, logvar)
6567
return z
6668

69+
6770
class Decoder(nn.Module):
6871
def __init__(self):
6972
super(Decoder, self).__init__()
@@ -75,14 +78,15 @@ def __init__(self):
7578
nn.BatchNorm1d(512),
7679
nn.LeakyReLU(0.2, inplace=True),
7780
nn.Linear(512, int(np.prod(img_shape))),
78-
nn.Tanh()
81+
nn.Tanh(),
7982
)
8083

8184
def forward(self, z):
8285
img_flat = self.model(z)
8386
img = img_flat.view(img_flat.shape[0], *img_shape)
8487
return img
8588

89+
8690
class Discriminator(nn.Module):
8791
def __init__(self):
8892
super(Discriminator, self).__init__()
@@ -93,13 +97,14 @@ def __init__(self):
9397
nn.Linear(512, 256),
9498
nn.LeakyReLU(0.2, inplace=True),
9599
nn.Linear(256, 1),
96-
nn.Sigmoid()
100+
nn.Sigmoid(),
97101
)
98102

99103
def forward(self, z):
100104
validity = self.model(z)
101105
return validity
102106

107+
103108
# Use binary cross-entropy loss
104109
adversarial_loss = torch.nn.BCELoss()
105110
pixelwise_loss = torch.nn.L1Loss()
@@ -117,29 +122,36 @@ def forward(self, z):
117122
pixelwise_loss.cuda()
118123

119124
# Configure data loader
120-
os.makedirs('../../data/mnist', exist_ok=True)
125+
os.makedirs("../../data/mnist", exist_ok=True)
121126
dataloader = torch.utils.data.DataLoader(
122-
datasets.MNIST('../../data/mnist', train=True, download=True,
123-
transform=transforms.Compose([
124-
transforms.Resize(opt.img_size),
125-
transforms.ToTensor(),
126-
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
127-
])),
128-
batch_size=opt.batch_size, shuffle=True)
127+
datasets.MNIST(
128+
"../../data/mnist",
129+
train=True,
130+
download=True,
131+
transform=transforms.Compose(
132+
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
133+
),
134+
),
135+
batch_size=opt.batch_size,
136+
shuffle=True,
137+
)
129138

130139
# Optimizers
131-
optimizer_G = torch.optim.Adam( itertools.chain(encoder.parameters(), decoder.parameters()),
132-
lr=opt.lr, betas=(opt.b1, opt.b2))
140+
optimizer_G = torch.optim.Adam(
141+
itertools.chain(encoder.parameters(), decoder.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
142+
)
133143
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
134144

135145
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
136146

147+
137148
def sample_image(n_row, batches_done):
138149
"""Saves a grid of generated digits"""
139150
# Sample noise
140-
z = Variable(Tensor(np.random.normal(0, 1, (n_row**2, opt.latent_dim))))
151+
z = Variable(Tensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
141152
gen_imgs = decoder(z)
142-
save_image(gen_imgs.data, 'images/%d.png' % batches_done, nrow=n_row, normalize=True)
153+
save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
154+
143155

144156
# ----------
145157
# Training
@@ -165,8 +177,9 @@ def sample_image(n_row, batches_done):
165177
decoded_imgs = decoder(encoded_imgs)
166178

167179
# Loss measures generator's ability to fool the discriminator
168-
g_loss = 0.001 * adversarial_loss(discriminator(encoded_imgs), valid) + \
169-
0.999 * pixelwise_loss(decoded_imgs, real_imgs)
180+
g_loss = 0.001 * adversarial_loss(discriminator(encoded_imgs), valid) + 0.999 * pixelwise_loss(
181+
decoded_imgs, real_imgs
182+
)
170183

171184
g_loss.backward()
172185
optimizer_G.step()
@@ -188,8 +201,10 @@ def sample_image(n_row, batches_done):
188201
d_loss.backward()
189202
optimizer_D.step()
190203

191-
print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader),
192-
d_loss.item(), g_loss.item()))
204+
print(
205+
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
206+
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
207+
)
193208

194209
batches_done = epoch * len(dataloader) + i
195210
if batches_done % opt.sample_interval == 0:

implementations/acgan/acgan.py

Lines changed: 48 additions & 44 deletions
Original file line number · Diff line number · Diff line change
@@ -14,41 +14,43 @@
1414
import torch.nn.functional as F
1515
import torch
1616

17-
os.makedirs('images', exist_ok=True)
17+
os.makedirs("images", exist_ok=True)
1818

1919
parser = argparse.ArgumentParser()
20-
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
21-
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
22-
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
23-
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
24-
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
25-
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
26-
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
27-
parser.add_argument('--n_classes', type=int, default=10, help='number of classes for dataset')
28-
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
29-
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
30-
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image sampling')
20+
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
21+
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
22+
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
23+
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
24+
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
25+
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
26+
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
27+
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
28+
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
29+
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
30+
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
3131
opt = parser.parse_args()
3232
print(opt)
3333

3434
cuda = True if torch.cuda.is_available() else False
3535

36+
3637
def weights_init_normal(m):
3738
classname = m.__class__.__name__
38-
if classname.find('Conv') != -1:
39+
if classname.find("Conv") != -1:
3940
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
40-
elif classname.find('BatchNorm2d') != -1:
41+
elif classname.find("BatchNorm2d") != -1:
4142
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
4243
torch.nn.init.constant_(m.bias.data, 0.0)
4344

45+
4446
class Generator(nn.Module):
4547
def __init__(self):
4648
super(Generator, self).__init__()
4749

4850
self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)
4951

50-
self.init_size = opt.img_size // 4 # Initial size before upsampling
51-
self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128*self.init_size**2))
52+
self.init_size = opt.img_size // 4 # Initial size before upsampling
53+
self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))
5254

5355
self.conv_blocks = nn.Sequential(
5456
nn.BatchNorm2d(128),
@@ -61,7 +63,7 @@ def __init__(self):
6163
nn.BatchNorm2d(64, 0.8),
6264
nn.LeakyReLU(0.2, inplace=True),
6365
nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
64-
nn.Tanh()
66+
nn.Tanh(),
6567
)
6668

6769
def forward(self, noise, labels):
@@ -71,15 +73,14 @@ def forward(self, noise, labels):
7173
img = self.conv_blocks(out)
7274
return img
7375

76+
7477
class Discriminator(nn.Module):
7578
def __init__(self):
7679
super(Discriminator, self).__init__()
7780

7881
def discriminator_block(in_filters, out_filters, bn=True):
7982
"""Returns layers of each discriminator block"""
80-
block = [ nn.Conv2d(in_filters, out_filters, 3, 2, 1),
81-
nn.LeakyReLU(0.2, inplace=True),
82-
nn.Dropout2d(0.25)]
83+
block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
8384
if bn:
8485
block.append(nn.BatchNorm2d(out_filters, 0.8))
8586
return block
@@ -92,13 +93,11 @@ def discriminator_block(in_filters, out_filters, bn=True):
9293
)
9394

9495
# The height and width of downsampled image
95-
ds_size = opt.img_size // 2**4
96+
ds_size = opt.img_size // 2 ** 4
9697

9798
# Output layers
98-
self.adv_layer = nn.Sequential( nn.Linear(128*ds_size**2, 1),
99-
nn.Sigmoid())
100-
self.aux_layer = nn.Sequential( nn.Linear(128*ds_size**2, opt.n_classes),
101-
nn.Softmax())
99+
self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
100+
self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())
102101

103102
def forward(self, img):
104103
out = self.conv_blocks(img)
@@ -108,6 +107,7 @@ def forward(self, img):
108107

109108
return validity, label
110109

110+
111111
# Loss functions
112112
adversarial_loss = torch.nn.BCELoss()
113113
auxiliary_loss = torch.nn.CrossEntropyLoss()
@@ -127,15 +127,19 @@ def forward(self, img):
127127
discriminator.apply(weights_init_normal)
128128

129129
# Configure data loader
130-
os.makedirs('../../data/mnist', exist_ok=True)
130+
os.makedirs("../../data/mnist", exist_ok=True)
131131
dataloader = torch.utils.data.DataLoader(
132-
datasets.MNIST('../../data/mnist', train=True, download=True,
133-
transform=transforms.Compose([
134-
transforms.Resize(opt.img_size),
135-
transforms.ToTensor(),
136-
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
137-
])),
138-
batch_size=opt.batch_size, shuffle=True)
132+
datasets.MNIST(
133+
"../../data/mnist",
134+
train=True,
135+
download=True,
136+
transform=transforms.Compose(
137+
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
138+
),
139+
),
140+
batch_size=opt.batch_size,
141+
shuffle=True,
142+
)
139143

140144
# Optimizers
141145
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
@@ -144,15 +148,17 @@ def forward(self, img):
144148
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
145149
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
146150

151+
147152
def sample_image(n_row, batches_done):
148153
"""Saves a grid of generated digits ranging from 0 to n_classes"""
149154
# Sample noise
150-
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row**2, opt.latent_dim))))
155+
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
151156
# Get labels ranging from 0 to n_classes for n rows
152157
labels = np.array([num for _ in range(n_row) for num in range(n_row)])
153158
labels = Variable(LongTensor(labels))
154159
gen_imgs = generator(z, labels)
155-
save_image(gen_imgs.data, 'images/%d.png' % batches_done, nrow=n_row, normalize=True)
160+
save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
161+
156162

157163
# ----------
158164
# Training
@@ -186,8 +192,7 @@ def sample_image(n_row, batches_done):
186192

187193
# Loss measures generator's ability to fool the discriminator
188194
validity, pred_label = discriminator(gen_imgs)
189-
g_loss = 0.5 * (adversarial_loss(validity, valid) + \
190-
auxiliary_loss(pred_label, gen_labels))
195+
g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))
191196

192197
g_loss.backward()
193198
optimizer_G.step()
@@ -200,13 +205,11 @@ def sample_image(n_row, batches_done):
200205

201206
# Loss for real images
202207
real_pred, real_aux = discriminator(real_imgs)
203-
d_real_loss = (adversarial_loss(real_pred, valid) + \
204-
auxiliary_loss(real_aux, labels)) / 2
208+
d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2
205209

206210
# Loss for fake images
207211
fake_pred, fake_aux = discriminator(gen_imgs.detach())
208-
d_fake_loss = (adversarial_loss(fake_pred, fake) + \
209-
auxiliary_loss(fake_aux, gen_labels)) / 2
212+
d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2
210213

211214
# Total discriminator loss
212215
d_loss = (d_real_loss + d_fake_loss) / 2
@@ -219,9 +222,10 @@ def sample_image(n_row, batches_done):
219222
d_loss.backward()
220223
optimizer_D.step()
221224

222-
print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader),
223-
d_loss.item(), 100 * d_acc,
224-
g_loss.item()))
225+
print(
226+
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
227+
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
228+
)
225229
batches_done = epoch * len(dataloader) + i
226230
if batches_done % opt.sample_interval == 0:
227231
sample_image(n_row=10, batches_done=batches_done)

0 commit comments

Comments (0)