
Commit 7383043

committed
initial upload
initial upload for PyTorch implementation of BEGAN
1 parent 56d08fe commit 7383043

File tree

1 file changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
import torch
from torch import nn
from torch.autograd import Variable
from torch import optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import pickle
import os

# Settings
batchSize = 16
imageSize = 64
z_dim = 64
n_channels = 3
conv_hidden_num = 64
outf = "./results"
if not os.path.exists(outf):  # save_image and pickle.dump below fail if the directory is missing
    os.makedirs(outf)

# Import dataset
des_dir = "./korCeleb64/"

dataset = dset.ImageFolder(root=des_dir,
                           transform=transforms.Compose([
                               transforms.Scale(imageSize),  # renamed transforms.Resize in later torchvision
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batchSize,
                                         shuffle=True)
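# Note: dset.ImageFolder expects the images to live in at least one
# subdirectory (e.g. ./korCeleb64/<some_folder>/img.png); the class labels
# it yields are discarded in the training loop below.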

# Design Discriminator in an autoencoder manner
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # encoder: 3x64x64 image -> z_dim embedding
        self.conv1 = nn.Sequential(
            nn.Conv2d(n_channels, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.Conv2d(conv_hidden_num, 2 * conv_hidden_num, 3, 2, 1),
            nn.ELU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(2 * conv_hidden_num, 2 * conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.Conv2d(2 * conv_hidden_num, 3 * conv_hidden_num, 3, 2, 1),
            nn.ELU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(3 * conv_hidden_num, 3 * conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.Conv2d(3 * conv_hidden_num, 3 * conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True)
        )

        self.fc1 = nn.Linear(4 * 4 * 3 * conv_hidden_num, z_dim)

        # decoder: z_dim embedding -> 3x64x64 reconstruction
        self.fc2 = nn.Linear(z_dim, 16 * 16 * conv_hidden_num)
        self.conv2 = nn.Sequential(
            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),

            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),

            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),

            nn.Conv2d(conv_hidden_num, 3, 3, 1, 1)
        )

    def forward(self, x):
        # through encoder conv layers
        conv = self.conv1(x)
        # embedding via encoder
        embedding = self.fc1(conv.view(-1, 3 * conv_hidden_num * 4 * 4))
        # reconstructing the image via decoder
        reconst = self.conv2(self.fc2(embedding).view(-1, conv_hidden_num, 16, 16))
        return embedding, reconst
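
# Shape check for a 64x64 input, following the layer definitions above:
# the two stride-2 convs and two 2x2 max-pools take 64 -> 32 -> 16 -> 8 -> 4,
# yielding the 3*conv_hidden_num*4*4 features consumed by fc1; the decoder's
# two nearest-neighbor upsamplings take 16 -> 32 -> 64, ending in 3x64x64.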

# Design Generator - the same structure as the Discriminator's decoder
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Linear(z_dim, 16 * 16 * conv_hidden_num)
        self.conv = nn.Sequential(
            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),

            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),

            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),
            nn.Conv2d(conv_hidden_num, conv_hidden_num, 3, 1, 1),
            nn.ELU(inplace=True),

            nn.Conv2d(conv_hidden_num, 3, 3, 1, 1)
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(-1, conv_hidden_num, 16, 16)
        out = self.conv(x)
        return out
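
# The Generator is a separate copy of the Discriminator's decoder: the same
# layer stack, but independent weights. Both map a z_dim vector to a 3x64x64 image.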

# make instances of the networks
discriminator = Discriminator()
generator = Generator()

# weight initialization
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:  # a no-op here: the models use no BatchNorm layers
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

discriminator.apply(weights_init)
generator.apply(weights_init)

# activate cuda
discriminator.cuda()
generator.cuda()
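# This script assumes a CUDA device throughout (.cuda() on the models, the
# loss, and every Variable below); there is no CPU fallback.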

# set optimizers and loss criterion
lr = 1e-5  # needs to be reduced in case of mode collapse
D_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
G_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = nn.L1Loss().cuda()

# lists to track training history
D_losses = []
G_losses = []
D_real_losses = []
D_fake_losses = []
measurements = []
k_ts = []
result_dict = {}

# set BEGAN parameters
k_t = 0
gamma = 0.75
lambda_k = 0.0001
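
# For reference, the BEGAN objective (Berthelot et al., 2017), where
# L(v) = |v - D(v)| is the L1 autoencoder reconstruction loss:
#   L_D     = L(x) - k_t * L(G(z))
#   L_G     = L(G(z))
#   k_{t+1} = k_t + lambda_k * (gamma * L(x) - L(G(z))), clamped to [0, 1]
#   M       = L(x) + |gamma * L(x) - L(G(z))|   (convergence measure)
# gamma trades image diversity against quality, and lambda_k acts as the
# learning rate of the proportional control on k_t.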

# fixed noise to check training progress
fixed_noise = Variable(torch.FloatTensor(batchSize * z_dim).uniform_(-1, 1)).view(batchSize, z_dim).cuda()

# train
for epoch in range(50000):

    fake_ = generator(fixed_noise)
    vutils.save_image(fake_.data, "{}/generated_img_{}_epoch.png".format(outf, format(epoch, "0>5")))

    for step, (data, _) in enumerate(dataloader):
        n_inputs = data.size()[0]

        # zero the gradients
        discriminator.zero_grad()
        generator.zero_grad()

        X_v = Variable(data).cuda()

        # put real images through the discriminator
        D_real_embedding, D_real_reconst = discriminator(X_v)
        D_real_loss = criterion(D_real_reconst, X_v)

        # put fake images from the generator through the discriminator
        noise = Variable(torch.FloatTensor(n_inputs * z_dim).uniform_(-1, 1)).view(n_inputs, z_dim).cuda()
        fake = generator(noise)

        D_fake_embedding, D_fake_reconst = discriminator(fake.detach())
        D_fake_loss_d = criterion(D_fake_reconst, fake.detach())  # gradient flows into D only
        D_fake_loss_g = criterion(fake, D_fake_reconst.detach())  # gradient flows into G only

        # calculate losses
        D_loss = D_real_loss - k_t * D_fake_loss_d
        G_loss = D_fake_loss_g

        # backprop & update networks
        D_loss.backward()
        G_loss.backward()

        D_optimizer.step()
        G_optimizer.step()

        # update k_t from the real and fake reconstruction losses
        k_t += lambda_k * (gamma * D_real_loss.data[0] - G_loss.data[0])
        k_t = max(min(k_t, 1), 0)

        # calculate convergence measure M = L(x) + |gamma*L(x) - L(G(z))|
        M = D_real_loss + torch.abs(gamma * D_real_loss - G_loss)

        if step % 100 == 0:
            print('[%d/%d][%d/%d] D_loss: %.4f G_loss: %.4f D_real_loss: %.4f D_fake_loss: %.4f M: %.4f k: %5f'
                  % (epoch, 50000, step, len(dataloader),
                     D_loss.data[0], G_loss.data[0], D_real_loss.data[0], D_fake_loss_d.data[0],
                     M.data[0], k_t))
            # save losses and scores
            D_losses.append(D_loss.data[0])
            G_losses.append(G_loss.data[0])
            D_real_losses.append(D_real_loss.data[0])
            D_fake_losses.append(D_fake_loss_d.data[0])
            measurements.append(M.data[0])
            k_ts.append(k_t)

            result_dict["D_losses"] = D_losses
            result_dict["G_losses"] = G_losses
            result_dict["D_real_losses"] = D_real_losses
            result_dict["D_fake_losses"] = D_fake_losses
            result_dict["measurements"] = measurements
            result_dict["k_ts"] = k_ts

            pickle.dump(result_dict, open("{}/result_dict.p".format(outf), "wb"))
            if epoch < 5:
                # save images generated from the fixed noise
                fake_ = generator(fixed_noise)
                vutils.save_image(fake_.data, "{}/generated_img_{}_epoch_{}_step.png".format(outf, format(epoch, "0>5"), step))

    if (epoch + 1) % 100 == 0:
        # decay the learning rate in place, without resetting the Adam state
        lr = max(lr * 0.955, 1e-7)
        for param_group in D_optimizer.param_groups:
            param_group['lr'] = lr
        for param_group in G_optimizer.param_groups:
            param_group['lr'] = lr
        # save models
        torch.save(discriminator.state_dict(), "D_epoch_{}.pth".format(epoch))
        torch.save(generator.state_dict(), "G_epoch_{}.pth".format(epoch))
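
# A saved checkpoint can later be restored with (epoch number illustrative):
#   discriminator.load_state_dict(torch.load("D_epoch_99.pth"))
#   generator.load_state_dict(torch.load("G_epoch_99.pth"))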
