@@ -49,47 +49,27 @@ def backward(self):
 class GAN(object):
 
-    def __init__(self):
-        self.n_epochs, self.batch_size = 3, 64
+    def __init__(self, conditioned=True):
+        self.n_epochs, self.batch_size = 1, 64
         self.gen_input = 100
+        self.n_classes = 10
+        self.conditioned = conditioned
         self.dc_gan()
 
-    def vanilla_gan(self):
-        gen_lr, dis_lr = 2e-3, 5e-4
-        self.generator = NN([
-            FullyConnect([self.gen_input], [256], lr=gen_lr),
-            BatchNormalization([256], lr=gen_lr),
-            Activation(act_type='ReLU'),
-            FullyConnect([256], [512], lr=gen_lr),
-            BatchNormalization([512], lr=gen_lr),
-            Activation(act_type='ReLU'),
-            FullyConnect([512], [1024], lr=gen_lr),
-            BatchNormalization([1024], lr=gen_lr),
-            Activation(act_type='ReLU'),
-            FullyConnect([1024], [1, 28, 28], lr=gen_lr),
-            Activation(act_type='Tanh')
-        ])
-        self.discriminator = NN([
-            FullyConnect([1, 28, 28], [1024], lr=dis_lr),
-            Activation(act_type='LeakyReLU'),
-            FullyConnect([1024], [512], lr=dis_lr),
-            Activation(act_type='LeakyReLU'),
-            FullyConnect([512], [256], lr=dis_lr),
-            Activation(act_type='LeakyReLU'),
-            FullyConnect([256], [1], lr=dis_lr),
-            Activation(act_type='Sigmoid')
-        ])
-
     def dc_gan(self):
-        gen_lr, dis_lr = 2e-3, 1e-3
-        tconv1 = TrasposedConv((128, 7, 7), k_size=4,
+        gen_lr, dis_lr = 4e-3, 1e-3
+        dense = FullyConnect(
+            [self.gen_input + self.n_classes if self.conditioned else self.gen_input],
+            (128, 7, 7), lr=gen_lr, optimizer='RMSProp'
+        )
+        tconv1 = TrasposedConv(dense.out_shape, k_size=4,
                                k_num=128, stride=2, padding=1, lr=gen_lr, optimizer='RMSProp')
         tconv2 = TrasposedConv(tconv1.out_shape, k_size=4,
                                k_num=128, stride=2, padding=1, lr=gen_lr, optimizer='RMSProp')
         tconv3 = TrasposedConv(tconv2.out_shape, k_size=7,
                                k_num=1, stride=1, padding=3, lr=gen_lr, optimizer='RMSProp')
         self.generator = NN([
-            FullyConnect([self.gen_input], tconv1.in_shape, lr=gen_lr, optimizer='RMSProp'),
+            dense,
             BatchNormalization(tconv1.in_shape, lr=gen_lr, optimizer='RMSProp'),
             Activation(act_type='ReLU'),
             tconv1,
@@ -102,8 +82,10 @@ def dc_gan(self):
             BatchNormalization(tconv3.out_shape, lr=gen_lr, optimizer='RMSProp'),
             Activation(act_type='Tanh')
         ])
-        conv1 = Conv((1, 28, 28), k_size=7, k_num=128,
-                     stride=1, padding=3, lr=dis_lr, optimizer='RMSProp')
+        conv1 = Conv(
+            (1 + self.n_classes if self.conditioned else 1, 28, 28),
+            k_size=7, k_num=128, stride=1, padding=3, lr=dis_lr, optimizer='RMSProp'
+        )
         conv2 = Conv(conv1.out_shape, k_size=4, k_num=128,
                      stride=2, padding=1, lr=dis_lr, optimizer='RMSProp')
         conv3 = Conv(conv2.out_shape, k_size=4, k_num=128,
@@ -121,47 +103,67 @@ def dc_gan(self):
             Activation(act_type='Sigmoid')
         ])
 
-    def fit(self, x):
+    def fit(self, x, labels):
         y_true = np.ones((self.batch_size, 1))
         y_false = np.zeros((self.batch_size, 1))
         y_dis = np.concatenate([y_true, y_false], axis=0)
-        generated_img = []
+        label_channels = np.repeat(labels, 28 * 28, axis=1).reshape(labels.shape[0], self.n_classes, 28, 28)
 
         for epoch in range(self.n_epochs):
             permut = np.random.permutation(
                 x.shape[0] // self.batch_size * self.batch_size).reshape([-1, self.batch_size])
             for b_idx in range(permut.shape[0]):
-                x_true = x[permut[b_idx, :]]
+                batch_label_channels = label_channels[permut[b_idx, :]]
+                if self.conditioned:
+                    x_true = np.concatenate((x[permut[b_idx, :]], batch_label_channels), axis=1)
+                else:
+                    x_true = x[permut[b_idx, :]]
                 pred_dis_true = self.discriminator.forward(x_true)
                 self.discriminator.gradient(bce_grad(pred_dis_true, y_true))
                 self.discriminator.backward()
-
-                x_gen = self.generator.forward(
-                    noise(self.batch_size, self.gen_input))
+
+                if self.conditioned:
+                    x_gen = self.generator.forward(
+                        np.concatenate((noise(self.batch_size, self.gen_input), labels[permut[b_idx, :]]), axis=1)
+                    )
+                    x_gen = np.concatenate((x_gen, batch_label_channels), axis=1)
+                else:
+                    x_gen = self.generator.forward(noise(self.batch_size, self.gen_input))
                 pred_dis_gen = self.discriminator.forward(x_gen)
                 self.discriminator.gradient(bce_grad(pred_dis_gen, y_false))
                 self.discriminator.backward()
 
                 pred_gen = self.discriminator.forward(x_gen)
                 grad = self.discriminator.gradient(bce_grad(pred_gen, y_true))
-                self.generator.gradient(grad)
+                if self.conditioned:
+                    self.generator.gradient(grad[:, :1, :, :])
+                else:
+                    self.generator.gradient(grad)
                 self.generator.backward()
                 print(
                     f'Epoch {epoch} batch {b_idx} discriminator:',
-                    bce_loss(np.concatenate(
-                        [pred_dis_true, pred_dis_gen], axis=0), y_dis),
+                    bce_loss(np.concatenate((pred_dis_true, pred_dis_gen)), y_dis),
                     'generator:', bce_loss(pred_gen, y_true)
                 )
-            generated_img.append(
-                self.generator.predict(noise(10, self.gen_input)))
-        return generated_img
 
 
 def main():
-    x, _ = fetch_openml('mnist_784', return_X_y=True, data_home='data', as_frame=False)
+    x, y = fetch_openml('mnist_784', return_X_y=True, data_home='data', as_frame=False)
     x = 2 * (x / x.max()) - 1
-    gan = GAN()
-    images = gan.fit(x.reshape((-1, 1, 28, 28)))
+    labels = np.zeros((y.shape[0], 10))
+    labels[range(y.shape[0]), y.astype(np.int_)] = 1
+    gan = GAN(conditioned=True)
+    gan.fit(x.reshape((-1, 1, 28, 28)), labels)
+
+    if gan.conditioned:
+        onehot = np.zeros((30, 10))
+        onehot[range(30), np.arange(30) % 10] = 1
+        images = gan.generator.predict(
+            np.concatenate((noise(30, gan.gen_input), onehot), axis=1)
+        )
+    else:
+        images = gan.generator.predict(noise(30, gan.gen_input))
+
 
     for i, img in enumerate(np.array(images).reshape(-1, 784)):
         plt.subplot(len(images), 10, i + 1)
         plt.imshow(img.reshape(28, 28), cmap='gray', vmin=-1, vmax=1)
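
A minimal standalone sketch of the label-to-channel broadcast used in the new fit(), assuming one-hot labels over 10 classes and 28x28 single-channel images (the class indices below are illustrative only):

import numpy as np

batch, n_classes = 4, 10
labels = np.zeros((batch, n_classes))
labels[range(batch), [3, 1, 7, 0]] = 1   # illustrative one-hot labels

# Repeat each one-hot entry 28*28 times, then reshape so every class
# becomes a constant 28x28 feature map, as fit() does with label_channels.
label_channels = np.repeat(labels, 28 * 28, axis=1).reshape(batch, n_classes, 28, 28)

images = np.random.uniform(-1, 1, (batch, 1, 28, 28))  # stand-in for an MNIST batch
x_true = np.concatenate((images, label_channels), axis=1)
print(x_true.shape)  # (4, 11, 28, 28): 1 image channel + 10 label channels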