Commit aaa5553

add conditional gan
1 parent 2316d5a commit aaa5553

File tree

3 files changed: +51 -50 lines

README.md

Lines changed: 1 addition & 2 deletions

@@ -62,9 +62,8 @@
 * Deep Belief Network (DBN) `deep_belief_network.py`
 * Variational autoencoder (VAE) `variational_autoencoder.py`
 * Generative Adversarial Network (GAN) `generative_adversarial_network.py`
-  * Vanilla GAN
   * Deep Convolutional GAN (DCGAN)
-    * discriminator vs generator
+  * Conditional GAN
 * Transfer Learning `transfer_learning.py`
   * CNN on MNIST - freeze convolutional and fine tune dense layers
 * Layers `nn_layers.py` / `simple_cnn_layers.py`

dc_gan.png

-36.8 KB

generative_adversarial_network.py

Lines changed: 50 additions & 48 deletions
@@ -49,47 +49,27 @@ def backward(self):
 
 class GAN(object):
 
-    def __init__(self):
-        self.n_epochs, self.batch_size = 3, 64
+    def __init__(self, conditioned=True):
+        self.n_epochs, self.batch_size = 1, 64
         self.gen_input = 100
+        self.n_classes = 10
+        self.conditioned = conditioned
         self.dc_gan()
 
-    def vanilla_gan(self):
-        gen_lr, dis_lr = 2e-3, 5e-4
-        self.generator = NN([
-            FullyConnect([self.gen_input], [256], lr=gen_lr),
-            BatchNormalization([256], lr=gen_lr),
-            Activation(act_type='ReLU'),
-            FullyConnect([256], [512], lr=gen_lr),
-            BatchNormalization([512], lr=gen_lr),
-            Activation(act_type='ReLU'),
-            FullyConnect([512], [1024], lr=gen_lr),
-            BatchNormalization([1024], lr=gen_lr),
-            Activation(act_type='ReLU'),
-            FullyConnect([1024], [1, 28, 28], lr=gen_lr),
-            Activation(act_type='Tanh')
-        ])
-        self.discriminator = NN([
-            FullyConnect([1, 28, 28], [1024], lr=dis_lr),
-            Activation(act_type='LeakyReLU'),
-            FullyConnect([1024], [512], lr=dis_lr),
-            Activation(act_type='LeakyReLU'),
-            FullyConnect([512], [256], lr=dis_lr),
-            Activation(act_type='LeakyReLU'),
-            FullyConnect([256], [1], lr=dis_lr),
-            Activation(act_type='Sigmoid')
-        ])
-
     def dc_gan(self):
-        gen_lr, dis_lr = 2e-3, 1e-3
-        tconv1 = TrasposedConv((128, 7, 7), k_size=4,
+        gen_lr, dis_lr = 4e-3, 1e-3
+        dense = FullyConnect(
+            [self.gen_input + self.n_classes if self.conditioned else self.gen_input],
+            (128, 7, 7), lr=gen_lr, optimizer='RMSProp'
+        )
+        tconv1 = TrasposedConv(dense.out_shape, k_size=4,
                                k_num=128, stride=2, padding=1, lr=gen_lr, optimizer='RMSProp')
         tconv2 = TrasposedConv(tconv1.out_shape, k_size=4,
                                k_num=128, stride=2, padding=1, lr=gen_lr, optimizer='RMSProp')
         tconv3 = TrasposedConv(tconv2.out_shape, k_size=7,
                                k_num=1, stride=1, padding=3, lr=gen_lr, optimizer='RMSProp')
         self.generator = NN([
-            FullyConnect([self.gen_input], tconv1.in_shape, lr=gen_lr, optimizer='RMSProp'),
+            dense,
             BatchNormalization(tconv1.in_shape, lr=gen_lr, optimizer='RMSProp'),
             Activation(act_type='ReLU'),
             tconv1,
@@ -102,8 +82,10 @@ def dc_gan(self):
             BatchNormalization(tconv3.out_shape, lr=gen_lr, optimizer='RMSProp'),
             Activation(act_type='Tanh')
         ])
-        conv1 = Conv((1, 28, 28), k_size=7, k_num=128,
-                     stride=1, padding=3, lr=dis_lr, optimizer='RMSProp')
+        conv1 = Conv(
+            (1 + self.n_classes if self.conditioned else 1, 28, 28),
+            k_size=7, k_num=128, stride=1, padding=3, lr=dis_lr, optimizer='RMSProp'
+        )
         conv2 = Conv(conv1.out_shape, k_size=4, k_num=128,
                      stride=2, padding=1, lr=dis_lr, optimizer='RMSProp')
         conv3 = Conv(conv2.out_shape, k_size=4, k_num=128,
@@ -121,47 +103,67 @@ def dc_gan(self):
             Activation(act_type='Sigmoid')
         ])
 
-    def fit(self, x):
+    def fit(self, x, labels):
         y_true = np.ones((self.batch_size, 1))
         y_false = np.zeros((self.batch_size, 1))
         y_dis = np.concatenate([y_true, y_false], axis=0)
-        generated_img = []
+        label_channels = np.repeat(labels, 28*28, axis=1).reshape(labels.shape[0], self.n_classes, 28, 28)
 
         for epoch in range(self.n_epochs):
             permut = np.random.permutation(
                 x.shape[0] // self.batch_size * self.batch_size).reshape([-1, self.batch_size])
             for b_idx in range(permut.shape[0]):
-                x_true = x[permut[b_idx, :]]
+                batch_label_channels = label_channels[permut[b_idx, :]]
+                if self.conditioned:
+                    x_true = np.concatenate((x[permut[b_idx, :]], batch_label_channels), axis=1)
+                else:
+                    x_true = x[permut[b_idx, :]]
                 pred_dis_true = self.discriminator.forward(x_true)
                 self.discriminator.gradient(bce_grad(pred_dis_true, y_true))
                 self.discriminator.backward()
-
-                x_gen = self.generator.forward(
-                    noise(self.batch_size, self.gen_input))
+
+                if self.conditioned:
+                    x_gen = self.generator.forward(
+                        np.concatenate((noise(self.batch_size, self.gen_input), labels[permut[b_idx, :]]), axis=1)
+                    )
+                    x_gen = np.concatenate((x_gen, batch_label_channels), axis=1)
+                else:
+                    x_gen = self.generator.forward(noise(self.batch_size, self.gen_input))
                 pred_dis_gen = self.discriminator.forward(x_gen)
                 self.discriminator.gradient(bce_grad(pred_dis_gen, y_false))
                 self.discriminator.backward()
 
                 pred_gen = self.discriminator.forward(x_gen)
                 grad = self.discriminator.gradient(bce_grad(pred_gen, y_true))
-                self.generator.gradient(grad)
+                if self.conditioned:
+                    self.generator.gradient(grad[:,:1,:,:])
+                else:
+                    self.generator.gradient(grad)
                 self.generator.backward()
                 print(
                     f'Epoch {epoch} batch {b_idx} discriminator:',
-                    bce_loss(np.concatenate(
-                        [pred_dis_true, pred_dis_gen], axis=0), y_dis),
+                    bce_loss(np.concatenate((pred_dis_true, pred_dis_gen)), y_dis),
                     'generator:', bce_loss(pred_gen, y_true)
                 )
-            generated_img.append(
-                self.generator.predict(noise(10, self.gen_input)))
-        return generated_img
 
 
 def main():
-    x, _ = fetch_openml('mnist_784', return_X_y=True, data_home='data', as_frame=False)
+    x, y = fetch_openml('mnist_784', return_X_y=True, data_home='data', as_frame=False)
     x = 2 * (x / x.max()) - 1
-    gan = GAN()
-    images = gan.fit(x.reshape((-1, 1, 28, 28)))
+    labels = np.zeros((y.shape[0], 10))
+    labels[range(y.shape[0]), y.astype(np.int_)] = 1
+    gan = GAN(conditioned=True)
+    gan.fit(x.reshape((-1, 1, 28, 28)), labels)
+
+    if gan.conditioned:
+        onehot = np.zeros((30, 10))
+        onehot[range(30), np.arange(30)%10] = 1
+        images = gan.generator.predict(
+            np.concatenate((noise(30, gan.gen_input), onehot), axis=1)
+        )
+    else:
+        images = gan.generator.predict(noise(30, gan.gen_input))
+
     for i, img in enumerate(np.array(images).reshape(-1, 784)):
         plt.subplot(len(images), 10, i + 1)
         plt.imshow(img.reshape(28, 28), cmap='gray', vmin=-1, vmax=1)
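For reference, a minimal NumPy-only sketch of the conditioning logic this commit adds (the helper names one_hot, generator_input and discriminator_input are illustrative, not part of the repository): class labels are one-hot encoded, appended to the noise vector fed to the generator, and broadcast to constant 28x28 label channels stacked behind the image for the discriminator.

import numpy as np

def one_hot(y, n_classes=10):
    # y: integer class labels of shape (N,)
    out = np.zeros((y.shape[0], n_classes))
    out[np.arange(y.shape[0]), y] = 1
    return out

def generator_input(noise_vec, onehot):
    # concatenate noise (N, 100) with one-hot labels (N, 10) -> (N, 110)
    return np.concatenate((noise_vec, onehot), axis=1)

def discriminator_input(images, onehot, h=28, w=28):
    # turn each one-hot label into constant h x w channels and stack them
    # behind the image channel: (N, 1, 28, 28) -> (N, 1 + n_classes, 28, 28)
    n, n_classes = onehot.shape
    label_channels = np.repeat(onehot, h * w, axis=1).reshape(n, n_classes, h, w)
    return np.concatenate((images, label_channels), axis=1)

# example shapes
z = np.random.randn(4, 100)
y = one_hot(np.array([0, 1, 2, 3]))
imgs = np.random.randn(4, 1, 28, 28)
print(generator_input(z, y).shape)         # (4, 110)
print(discriminator_input(imgs, y).shape)  # (4, 11, 28, 28)

Because the generator emits only the single image channel while the discriminator sees 1 + n_classes channels, the training loop above passes back only the first channel of the discriminator's input gradient (grad[:,:1,:,:]); the label channels are constants and receive no update.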
