14
14
import torch .nn .functional as F
15
15
import torch
16
16
17
# Directory that receives the sampled image grids ("images/<batches_done>.png").
os.makedirs("images", exist_ok=True)

# Command-line hyperparameters. Every option has a default, so the script
# can be launched without arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# b2 is Adam's second beta (decay of the second-moment estimate); the help
# text previously repeated the b1 description.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

# is_available() already returns a bool; no conditional expression needed.
cuda = torch.cuda.is_available()
35
35
36
+
36
37
def weights_init_normal(m):
    """Re-initialise the parameters of a single layer in place.

    Meant to be passed to ``Module.apply`` so it is called once per
    sub-module. Layers are selected by class name:

    * name contains ``"Conv"``: weights drawn from a normal distribution
      with mean 0.0 and std 0.02;
    * name contains ``"BatchNorm2d"``: weights drawn from a normal
      distribution with mean 1.0 and std 0.02, bias set to 0.0.

    Any other layer is left untouched.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        # With affine=False the layer has no learnable scale/shift and both
        # attributes are None; skip instead of failing on `.data`.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
43
44
45
+
44
46
class Generator (nn .Module ):
45
47
def __init__ (self ):
46
48
super (Generator , self ).__init__ ()
47
49
48
50
self .label_emb = nn .Embedding (opt .n_classes , opt .latent_dim )
49
51
50
- self .init_size = opt .img_size // 4 # Initial size before upsampling
51
- self .l1 = nn .Sequential (nn .Linear (opt .latent_dim , 128 * self .init_size ** 2 ))
52
+ self .init_size = opt .img_size // 4 # Initial size before upsampling
53
+ self .l1 = nn .Sequential (nn .Linear (opt .latent_dim , 128 * self .init_size ** 2 ))
52
54
53
55
self .conv_blocks = nn .Sequential (
54
56
nn .BatchNorm2d (128 ),
@@ -61,7 +63,7 @@ def __init__(self):
61
63
nn .BatchNorm2d (64 , 0.8 ),
62
64
nn .LeakyReLU (0.2 , inplace = True ),
63
65
nn .Conv2d (64 , opt .channels , 3 , stride = 1 , padding = 1 ),
64
- nn .Tanh ()
66
+ nn .Tanh (),
65
67
)
66
68
67
69
def forward (self , noise , labels ):
@@ -71,15 +73,14 @@ def forward(self, noise, labels):
71
73
img = self .conv_blocks (out )
72
74
return img
73
75
76
+
74
77
class Discriminator (nn .Module ):
75
78
def __init__ (self ):
76
79
super (Discriminator , self ).__init__ ()
77
80
78
81
def discriminator_block (in_filters , out_filters , bn = True ):
79
82
"""Returns layers of each discriminator block"""
80
- block = [ nn .Conv2d (in_filters , out_filters , 3 , 2 , 1 ),
81
- nn .LeakyReLU (0.2 , inplace = True ),
82
- nn .Dropout2d (0.25 )]
83
+ block = [nn .Conv2d (in_filters , out_filters , 3 , 2 , 1 ), nn .LeakyReLU (0.2 , inplace = True ), nn .Dropout2d (0.25 )]
83
84
if bn :
84
85
block .append (nn .BatchNorm2d (out_filters , 0.8 ))
85
86
return block
@@ -92,13 +93,11 @@ def discriminator_block(in_filters, out_filters, bn=True):
92
93
)
93
94
94
95
# The height and width of downsampled image
95
- ds_size = opt .img_size // 2 ** 4
96
+ ds_size = opt .img_size // 2 ** 4
96
97
97
98
# Output layers
98
- self .adv_layer = nn .Sequential ( nn .Linear (128 * ds_size ** 2 , 1 ),
99
- nn .Sigmoid ())
100
- self .aux_layer = nn .Sequential ( nn .Linear (128 * ds_size ** 2 , opt .n_classes ),
101
- nn .Softmax ())
99
+ self .adv_layer = nn .Sequential (nn .Linear (128 * ds_size ** 2 , 1 ), nn .Sigmoid ())
100
+ self .aux_layer = nn .Sequential (nn .Linear (128 * ds_size ** 2 , opt .n_classes ), nn .Softmax ())
102
101
103
102
def forward (self , img ):
104
103
out = self .conv_blocks (img )
@@ -108,6 +107,7 @@ def forward(self, img):
108
107
109
108
return validity , label
110
109
110
+
111
111
# Loss functions
112
112
adversarial_loss = torch .nn .BCELoss ()
113
113
auxiliary_loss = torch .nn .CrossEntropyLoss ()
@@ -127,15 +127,19 @@ def forward(self, img):
127
127
discriminator .apply (weights_init_normal )
128
128
129
129
# Configure data loader
130
# Make sure the download target exists before torchvision writes into it.
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(opt.img_size),
                transforms.ToTensor(),
                # One mean/std entry: MNIST images have a single channel.
                transforms.Normalize([0.5], [0.5]),
            ]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)
139
143
140
144
# Optimizers
141
145
optimizer_G = torch .optim .Adam (generator .parameters (), lr = opt .lr , betas = (opt .b1 , opt .b2 ))
@@ -144,15 +148,17 @@ def forward(self, img):
144
148
FloatTensor = torch .cuda .FloatTensor if cuda else torch .FloatTensor
145
149
LongTensor = torch .cuda .LongTensor if cuda else torch .LongTensor
146
150
151
+
147
152
def sample_image(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of generated images.

    The grid is written to ``images/<batches_done>.png``. Labels cycle
    through ``0 .. n_row - 1`` once per row, so every image in a given
    column is generated for the same label.

    Args:
        n_row: number of rows and columns of the grid. The labels are fed to
            the generator's label embedding, so ``n_row`` must not exceed
            ``opt.n_classes``.
        batches_done: batch counter used as the output file name.
    """
    # Sample noise (NumPy RNG, one latent vector per grid cell).
    z = FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim)))
    # Labels 0 .. n_row - 1, repeated for each of the n_row rows.
    labels = LongTensor(np.array([num for _ in range(n_row) for num in range(n_row)]))
    # Pure inference: no autograd graph is needed for the saved samples.
    with torch.no_grad():
        gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
161
+
156
162
157
163
# ----------
158
164
# Training
@@ -186,8 +192,7 @@ def sample_image(n_row, batches_done):
186
192
187
193
# Loss measures generator's ability to fool the discriminator
188
194
validity , pred_label = discriminator (gen_imgs )
189
- g_loss = 0.5 * (adversarial_loss (validity , valid ) + \
190
- auxiliary_loss (pred_label , gen_labels ))
195
+ g_loss = 0.5 * (adversarial_loss (validity , valid ) + auxiliary_loss (pred_label , gen_labels ))
191
196
192
197
g_loss .backward ()
193
198
optimizer_G .step ()
@@ -200,13 +205,11 @@ def sample_image(n_row, batches_done):
200
205
201
206
# Loss for real images
202
207
real_pred , real_aux = discriminator (real_imgs )
203
- d_real_loss = (adversarial_loss (real_pred , valid ) + \
204
- auxiliary_loss (real_aux , labels )) / 2
208
+ d_real_loss = (adversarial_loss (real_pred , valid ) + auxiliary_loss (real_aux , labels )) / 2
205
209
206
210
# Loss for fake images
207
211
fake_pred , fake_aux = discriminator (gen_imgs .detach ())
208
- d_fake_loss = (adversarial_loss (fake_pred , fake ) + \
209
- auxiliary_loss (fake_aux , gen_labels )) / 2
212
+ d_fake_loss = (adversarial_loss (fake_pred , fake ) + auxiliary_loss (fake_aux , gen_labels )) / 2
210
213
211
214
# Total discriminator loss
212
215
d_loss = (d_real_loss + d_fake_loss ) / 2
@@ -219,9 +222,10 @@ def sample_image(n_row, batches_done):
219
222
d_loss .backward ()
220
223
optimizer_D .step ()
221
224
222
- print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]" % (epoch , opt .n_epochs , i , len (dataloader ),
223
- d_loss .item (), 100 * d_acc ,
224
- g_loss .item ()))
225
+ print (
226
+ "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
227
+ % (epoch , opt .n_epochs , i , len (dataloader ), d_loss .item (), 100 * d_acc , g_loss .item ())
228
+ )
225
229
batches_done = epoch * len (dataloader ) + i
226
230
if batches_done % opt .sample_interval == 0 :
227
231
sample_image (n_row = 10 , batches_done = batches_done )
0 commit comments