Skip to content

Commit f85f8b2

Browse files
committed
Added old updated files from seperate account
1 parent 88ee545 commit f85f8b2

11 files changed

+312
-251
lines changed

README.md

+33-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,34 @@
11
# pytorch-GAN
2-
PyTorch implementation of the original GAN paper by Goodfellow et al.
2+
A PyTorch Implementation of Goodfellow et al.'s Paper on Generative Adversarial Networks. Find the paper at: https://arxiv.org/pdf/1406.2661.pdf
3+
4+
## How to run:
5+
Currently has MNIST experiment implemented. Built with torch 1.1.0 and python3.6.
6+
7+
`pip install -r requirements.txt`
8+
9+
`python train.py --epochs 300 --lr 1e-4 --batch-size 32`
10+
11+
Once train.py is running one can open a new shell and running tensboard in order to track various metrics and current generated images during training.
12+
13+
`tensorboard --logdir=runs/<CURRENT_RUN_DIRECTORY>`
14+
15+
### How to adjust hyperparameters:
16+
**One can use different arguments defined in train.py to adjust various hyperparameters**
17+
18+
```
19+
--epochs EPOCHS number of epochs to train for (default: 300)
20+
--lr LR learning rate for optimizer (default: 1e-4)
21+
--batch-size BATCH_SIZE
22+
number of examples in a batch (default: 32)
23+
--device DEVICE device to train on (default: cuda:0 if cuda is
24+
available otherwise cpu)
25+
--latent-size LATENT_SIZE
26+
size of latent space vectors (default: 64)
27+
--g-hidden-size G_HIDDEN_SIZE
28+
number of hidden units per layer in G (default: 256)
29+
--d-hidden-size D_HIDDEN_SIZE
30+
number of hidden units per layer in D (default: 256)
31+
```
32+
33+
## Results:
34+
![Epoch 2](https://i.imgur.com/MbMaKga.png) ![Epoch 20](https://i.imgur.com/W2po4XH.png) ![Epoch 499](https://i.imgur.com/MBs5P0q.png) ![Epoch 999](https://i.imgur.com/gJ2XoPk.png)

criteria/discriminator.py

-27
This file was deleted.

criteria/generator.py

-25
This file was deleted.

data.py

-30
This file was deleted.

generator.py

-42
This file was deleted.

mnist_data.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch as t
2+
from torchvision import datasets, transforms
3+
from torch.utils.data import DataLoader
4+
5+
6+
class GANData:
7+
8+
def __init__(self, args, root='./data',):
9+
self.args = args
10+
self.mnist_dataset = MNISTGANDataset(root=root)
11+
self.real_loader = DataLoader(
12+
self.mnist_dataset,
13+
batch_size=args.batch_size,
14+
shuffle=True
15+
)
16+
17+
def sample_latent_space(self, batch_size=None):
18+
"""
19+
Sample a normal distribution for latent space vectors
20+
(usually denoted by z)
21+
:return: a BATCH SIZE x LATENT SIZE tensor
22+
"""
23+
batch_size = self.args.batch_size if batch_size is None else batch_size
24+
return t.randn(batch_size, self.args.latent_size)
25+
26+
def get_fake_labels(self):
27+
"""
28+
:return: a vector of zeros of length batch size
29+
"""
30+
return t.zeros(self.args.batch_size, 1)
31+
32+
33+
class MNISTGANDataset(datasets.MNIST):
34+
35+
def __init__(self, root):
36+
super(MNISTGANDataset, self).__init__(
37+
root=root,
38+
train=True,
39+
download=True,
40+
transform=transforms.Compose([
41+
transforms.ToTensor(),
42+
transforms.Normalize((0.5,), (0.5,))
43+
])
44+
)
45+
46+
def __getitem__(self, index):
47+
"""
48+
Args:
49+
index (int): Index
50+
51+
Returns:
52+
tuple: (image, target) where target indicates that this
53+
is a real image: 1
54+
"""
55+
56+
# Replace target with ones
57+
img, target = super().__getitem__(index)
58+
return img, t.ones(1)
+14-17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import torch as t
21
import torch.nn as nn
3-
import torch.nn.functional as F
42

53

64
class Discriminator(nn.Module):
@@ -11,14 +9,21 @@ class Discriminator(nn.Module):
119
the input belongs to a the real data distribution.
1210
"""
1311

14-
def __init__(self, img_size):
12+
def __init__(self, img_size, hidden_size):
1513
super(Discriminator, self).__init__()
1614

1715
self.img_size = img_size
18-
self.l1 = nn.Linear(img_size, 64)
19-
self.l2 = nn.Linear(64, 128)
20-
self.l3 = nn.Linear(128, 64)
21-
self.l4 = nn.Linear(64, 1)
16+
17+
self.model = nn.Sequential(
18+
nn.Linear(img_size, hidden_size),
19+
nn.LeakyReLU(),
20+
nn.Dropout(),
21+
nn.Linear(hidden_size, hidden_size),
22+
nn.LeakyReLU(),
23+
nn.Dropout(),
24+
nn.Linear(hidden_size, 1),
25+
nn.Sigmoid()
26+
)
2227

2328
def forward(self, x):
2429
"""
@@ -27,14 +32,6 @@ def forward(self, x):
2732
:param x: Image tensor
2833
:return: Float in range [0, 1] - probability score
2934
"""
30-
31-
# Resize x into a vector
35+
# Resize x from a H x W img to a vector
3236
x = x.view(-1, self.img_size)
33-
34-
# Pass through layers with a non-linearity
35-
x = F.relu(self.l1(x))
36-
x = F.relu(self.l2(x))
37-
x = F.relu(self.l3(x))
38-
39-
# Use sigmoid to convert to a probability
40-
return t.sigmoid(self.l4(x))
37+
return self.model(x).clamp(1e-9)

models/generator.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import torch.nn as nn
2+
3+
4+
class Generator(nn.Module):
5+
"""
6+
Generative Adversarial Network Generator Class
7+
8+
Takes in a latent vector z and returns a vector in
9+
the same image space that the discriminator is trained on.
10+
11+
"""
12+
13+
def __init__(self, img_size, latent_size, hidden_size):
14+
super(Generator, self).__init__()
15+
16+
self.latent_size = latent_size
17+
18+
self.model = nn.Sequential(
19+
nn.Linear(latent_size, hidden_size),
20+
nn.ReLU(),
21+
nn.Dropout(),
22+
nn.Linear(hidden_size, hidden_size),
23+
nn.ReLU(),
24+
nn.Dropout(),
25+
nn.Linear(hidden_size, img_size),
26+
nn.Tanh()
27+
)
28+
29+
def forward(self, z):
30+
"""
31+
Forward pass of a generator
32+
33+
:param z: Latent space vector - size: batch_size x latent_size
34+
:return: Tensor of self.img_size
35+
"""
36+
return self.model(z)
37+

requirements.txt

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
absl-py==0.7.1
2+
certifi==2019.6.16
3+
future==0.17.1
4+
grpcio==1.22.0
5+
Markdown==3.1.1
6+
numpy==1.16.4
7+
Pillow==6.1.0
8+
protobuf==3.9.0
9+
six==1.12.0
10+
tb-nightly==1.15.0a20190720
11+
torch==1.1.0
12+
torchvision==0.3.0
13+
tqdm==4.32.2
14+
Werkzeug==0.15.5

0 commit comments

Comments
 (0)