Skip to content

Commit 927c0d2

Browse files
committed
VAE-MNIST init
VAE-MNIST init
1 parent 7afc4f2 commit 927c0d2

File tree

5 files changed

+395
-0
lines changed

5 files changed

+395
-0
lines changed

.DS_Store

0 Bytes
Binary file not shown.

GAN/.DS_Store

0 Bytes
Binary file not shown.

VAE/.DS_Store

0 Bytes
Binary file not shown.

VAE/구현/.DS_Store

6 KB
Binary file not shown.
Lines changed: 395 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,395 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"## Import"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": 2,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"import torch\n",
17+
"from torch import nn\n",
18+
"import torchvision.transforms as transforms\n",
19+
"import torch.backends.cudnn as cudnn\n",
20+
"import torchvision.datasets as dsets\n",
21+
"import itertools\n",
22+
"from torch.autograd import Variable"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 3,
28+
"metadata": {
29+
"collapsed": true
30+
},
31+
"outputs": [],
32+
"source": [
33+
"batchSize = 100\n",
34+
"z_size = 100\n",
35+
"h_size = 128"
36+
]
37+
},
38+
{
39+
"cell_type": "markdown",
40+
"metadata": {},
41+
"source": [
42+
"# Data Set"
43+
]
44+
},
45+
{
46+
"cell_type": "code",
47+
"execution_count": 4,
48+
"metadata": {},
49+
"outputs": [
50+
{
51+
"name": "stdout",
52+
"output_type": "stream",
53+
"text": [
54+
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
55+
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
56+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
57+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
58+
"Processing...\n",
59+
"Done!\n"
60+
]
61+
}
62+
],
63+
"source": [
64+
"transform = transforms.Compose([ \n",
65+
" transforms.ToTensor()\n",
66+
"])\n",
67+
"\n",
68+
"cudnn.benchmark = True\n",
69+
"\n",
70+
"train_dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform) \n",
71+
"data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)"
72+
]
73+
},
74+
{
75+
"cell_type": "markdown",
76+
"metadata": {},
77+
"source": [
78+
"# Build Model"
79+
]
80+
},
81+
{
82+
"cell_type": "code",
83+
"execution_count": 5,
84+
"metadata": {
85+
"collapsed": true
86+
},
87+
"outputs": [],
88+
"source": [
89+
"class Encoder(nn.Module):\n",
90+
" def __init__(self):\n",
91+
" super(Encoder, self).__init__()\n",
92+
" \n",
93+
" self.fc0 = nn.Sequential(\n",
94+
" \n",
95+
" nn.Linear(28*28,1024),\n",
96+
" nn.BatchNorm1d(1024),\n",
97+
" nn.LeakyReLU(0.1),\n",
98+
" \n",
99+
" nn.Linear(1024,512),\n",
100+
" nn.BatchNorm1d(512),\n",
101+
" nn.LeakyReLU(0.1),\n",
102+
" \n",
103+
" nn.Linear(512,256),\n",
104+
" nn.BatchNorm1d(256),\n",
105+
" nn.LeakyReLU(0.1),\n",
106+
" \n",
107+
" )\n",
108+
" \n",
109+
" self.fc1 = nn.Sequential(\n",
110+
" nn.Linear(256,100),\n",
111+
" nn.LeakyReLU(0.1)\n",
112+
" )\n",
113+
" \n",
114+
" self.fc2 = nn.Sequential(\n",
115+
" nn.Linear(256,100),\n",
116+
" nn.LeakyReLU(0.1)\n",
117+
" )\n",
118+
" \n",
119+
" def forward(self,x):\n",
120+
" x = x.view(batchSize,-1)\n",
121+
" x = self.fc0(x)\n",
122+
" z_mu = self.fc1(x)\n",
123+
" z_log_sigma = self.fc2(x)\n",
124+
" \n",
125+
" return z_mu, z_log_sigma\n",
126+
" \n",
127+
" "
128+
]
129+
},
130+
{
131+
"cell_type": "code",
132+
"execution_count": 6,
133+
"metadata": {
134+
"collapsed": true
135+
},
136+
"outputs": [],
137+
"source": [
138+
"class Decoder(nn.Module):\n",
139+
" def __init__(self):\n",
140+
" super(Decoder, self).__init__()\n",
141+
"\n",
142+
" self.model = nn.Sequential(\n",
143+
" \n",
144+
" nn.Linear(100,256),\n",
145+
" nn.BatchNorm1d(256),\n",
146+
" nn.LeakyReLU(0.1),\n",
147+
" \n",
148+
" nn.Linear(256,512),\n",
149+
" nn.BatchNorm1d(512),\n",
150+
" nn.LeakyReLU(0.1),\n",
151+
" \n",
152+
" nn.Linear(512,1024),\n",
153+
" nn.BatchNorm1d(1024),\n",
154+
" nn.LeakyReLU(0.1),\n",
155+
" \n",
156+
" nn.Linear(1024,28*28),\n",
157+
" nn.Sigmoid()\n",
158+
" )\n",
159+
" \n",
160+
" def forward(self,x):\n",
161+
" x = self.model(x)\n",
162+
" return x\n",
163+
" \n",
164+
" "
165+
]
166+
},
167+
{
168+
"cell_type": "code",
169+
"execution_count": 7,
170+
"metadata": {
171+
"collapsed": true
172+
},
173+
"outputs": [],
174+
"source": [
175+
"def sample_z(z_mu, z_log_sigma):\n",
176+
" epsilon = torch.FloatTensor(100*100).normal_(0,1).view((100,100))\n",
177+
" z_samples = z_mu + torch.mul(z_log_sigma, Variable(epsilon))\n",
178+
" \n",
179+
" return z_samples"
180+
]
181+
},
182+
{
183+
"cell_type": "code",
184+
"execution_count": 8,
185+
"metadata": {
186+
"collapsed": true,
187+
"scrolled": true
188+
},
189+
"outputs": [],
190+
"source": [
191+
"encoder = Encoder()\n",
192+
"decoder = Decoder()"
193+
]
194+
},
195+
{
196+
"cell_type": "code",
197+
"execution_count": 12,
198+
"metadata": {
199+
"collapsed": true
200+
},
201+
"outputs": [],
202+
"source": [
203+
"criterion1 = nn.BCELoss()\n",
204+
"criterion2 = nn.KLDivLoss()"
205+
]
206+
},
207+
{
208+
"cell_type": "code",
209+
"execution_count": 13,
210+
"metadata": {
211+
"collapsed": true
212+
},
213+
"outputs": [],
214+
"source": [
215+
"optimizer = torch.optim.Adam(itertools.chain(encoder.parameters(),decoder.parameters()),\n",
216+
" lr=1e-4,\n",
217+
" betas = (0.5,0.999)\n",
218+
" )\n"
219+
]
220+
},
221+
{
222+
"cell_type": "code",
223+
"execution_count": 14,
224+
"metadata": {
225+
"collapsed": true
226+
},
227+
"outputs": [
228+
{
229+
"data": {
230+
"application/vnd.jupyter.widget-view+json": {
231+
"model_id": "3ed42f131c3e4256a6b0b05d62376dea"
232+
}
233+
},
234+
"metadata": {},
235+
"output_type": "display_data"
236+
},
237+
{
238+
"name": "stdout",
239+
"output_type": "stream",
240+
"text": [
241+
"epoch: 22, step: 1, loss: 154.88479614257812\n",
242+
"\n"
243+
]
244+
},
245+
{
246+
"ename": "KeyboardInterrupt",
247+
"evalue": "",
248+
"output_type": "error",
249+
"traceback": [
250+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
251+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
252+
"\u001b[0;32m<ipython-input-14-65d0b7e5058a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreconstruction_loss\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mKLD\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
253+
"\u001b[0;32m//anaconda/lib/python3.5/site-packages/torch/autograd/variable.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_variables)\u001b[0m\n\u001b[1;32m 144\u001b[0m 'or with gradient w.r.t. the variable')\n\u001b[1;32m 145\u001b[0m \u001b[0mgradient\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresize_as_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfill_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 146\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_execution_engine\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_backward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_variables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 147\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
254+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
255+
]
256+
}
257+
],
258+
"source": [
259+
"niter = 20\n",
260+
"\n",
261+
"for epoch in range(21,100):\n",
262+
" for i, (data,_) in enumerate(tqdm_notebook(data_loader)):\n",
263+
" encoder.zero_grad()\n",
264+
" decoder.zero_grad()\n",
265+
" \n",
266+
" data_v = Variable(data)\n",
267+
" z_mu, z_log_sigma = encoder(data_v)\n",
268+
" z_samples = sample_z(z_mu,z_log_sigma)\n",
269+
" fake = decoder(z_samples)\n",
270+
" \n",
271+
" reconstruction_loss = criterion1(fake,data_v)\n",
272+
" KLD_element = z_mu.pow(2).add_(z_log_sigma.pow(2).exp()).mul_(-1).add_(1).add_(z_log_sigma.pow(2))\n",
273+
" KLD = torch.sum(KLD_element).mul_(-0.5)\n",
274+
" \n",
275+
" loss = reconstruction_loss + KLD\n",
276+
" \n",
277+
" loss.backward()\n",
278+
" \n",
279+
" optimizer.step()\n",
280+
" \n",
281+
" if i%100 == 0:\n",
282+
" print(\"epoch: {}, step: {}, loss: {}\".format(epoch+1,i+1,loss.data[0]))\n",
283+
" \n",
284+
" \n",
285+
" \n",
286+
" \n",
287+
" # 결고ㅏ 이미지 저장\n",
288+
" Z_v = Variable(torch.FloatTensor(8*8*z_size).normal_(0,1).view(64,-1))\n",
289+
"\n",
290+
" samples = decoder(Z_v).data.numpy()\n",
291+
" fig = plt.figure(figsize=(4, 4))\n",
292+
" gs = gridspec.GridSpec(8, 8)\n",
293+
" gs.update(wspace=0.05, hspace=0.05)\n",
294+
" for j, sample in enumerate(samples):\n",
295+
" ax = plt.subplot(gs[j])\n",
296+
" plt.axis('off')\n",
297+
" ax.set_xticklabels([])\n",
298+
" ax.set_yticklabels([])\n",
299+
" ax.set_aspect('equal')\n",
300+
" plt.imshow(sample.reshape(28, 28), cmap='Greys_r')\n",
301+
" fig.savefig(\"test_imgs_{}_{}.png\".format(epoch,i))\n",
302+
"\n",
303+
"\n",
304+
" \n",
305+
" "
306+
]
307+
},
308+
{
309+
"cell_type": "code",
310+
"execution_count": null,
311+
"metadata": {
312+
"collapsed": true
313+
},
314+
"outputs": [],
315+
"source": [
316+
"torch.save(encoder.state_dict(),\"encoder.pth\")\n",
317+
"torch.save(decoder.state_dict(),\"decoder.pth\")"
318+
]
319+
},
320+
{
321+
"cell_type": "code",
322+
"execution_count": 61,
323+
"metadata": {
324+
"collapsed": true
325+
},
326+
"outputs": [],
327+
"source": [
328+
"import matplotlib.pyplot as plt\n",
329+
"import matplotlib.gridspec as gridspec\n",
330+
"% matplotlib inline"
331+
]
332+
},
333+
{
334+
"cell_type": "code",
335+
"execution_count": null,
336+
"metadata": {
337+
"collapsed": true,
338+
"scrolled": false
339+
},
340+
"outputs": [],
341+
"source": [
342+
"Z_v = Variable(torch.FloatTensor(8*8*z_size).normal_(0,1).view(64,-1))\n",
343+
"\n",
344+
"samples = decoder(Z_v).data.numpy()\n",
345+
"fig = plt.figure(figsize=(4, 4))\n",
346+
"gs = gridspec.GridSpec(8, 8)\n",
347+
"gs.update(wspace=0.05, hspace=0.05)\n",
348+
"for j, sample in enumerate(samples):\n",
349+
" ax = plt.subplot(gs[j])\n",
350+
" plt.axis('off')\n",
351+
" ax.set_xticklabels([])\n",
352+
" ax.set_yticklabels([])\n",
353+
" ax.set_aspect('equal')\n",
354+
" plt.imshow(sample.reshape(28, 28), cmap='Greys_r')\n",
355+
"fig.savefig(\"test_imgs_{}_{}.png\".format(epoch,i))\n",
356+
"\n"
357+
]
358+
},
359+
{
360+
"cell_type": "markdown",
361+
"metadata": {},
362+
"source": [
363+
"#### 18th Epoch Image"
364+
]
365+
},
366+
{
367+
"cell_type": "markdown",
368+
"metadata": {},
369+
"source": [
370+
"<a href=\"https://imgur.com/Ri1HUA2\"><img src=\"https://i.imgur.com/Ri1HUA2.png\" title=\"source: imgur.com\" /></a>"
371+
]
372+
}
373+
],
374+
"metadata": {
375+
"kernelspec": {
376+
"display_name": "Python 3",
377+
"language": "python",
378+
"name": "python3"
379+
},
380+
"language_info": {
381+
"codemirror_mode": {
382+
"name": "ipython",
383+
"version": 3
384+
},
385+
"file_extension": ".py",
386+
"mimetype": "text/x-python",
387+
"name": "python",
388+
"nbconvert_exporter": "python",
389+
"pygments_lexer": "ipython3",
390+
"version": "3.6.2"
391+
}
392+
},
393+
"nbformat": 4,
394+
"nbformat_minor": 2
395+
}

0 commit comments

Comments
 (0)