@@ -1,33 +1,21 @@
import itertools
- import matplotlib as mpl
import numpy as np
import os
import tensorflow as tf
- import tensorflow.contrib.slim as slim
+ import tensorflow.keras as tfk
import time
- import seaborn as sns
-
- from matplotlib import pyplot as plt
+ import tensorflow_datasets as tfds
+ import tensorflow_probability as tfp
from imageio import imwrite
- from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
-
- sns.set_style('whitegrid')
-
- distributions = tf.distributions
+ tfkl = tfk.layers
+ tfc = tf.compat.v1
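+ # The compat.v1 graph-mode APIs used below (placeholders, sessions, summaries)
+ # assume eager execution is off, which TF2 requires disabling explicitly.
+ tfc.disable_eager_execution()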

- flags = tf.app.flags
+ flags = tfc.app.flags
flags.DEFINE_string('data_dir', '/tmp/dat/', 'Directory for data')
flags.DEFINE_string('logdir', '/tmp/log/', 'Directory for logs')
-
- # For making plots:
- # flags.DEFINE_integer('latent_dim', 2, 'Latent dimensionality of model')
- # flags.DEFINE_integer('batch_size', 64, 'Minibatch size')
- # flags.DEFINE_integer('n_samples', 10, 'Number of samples to save')
- # flags.DEFINE_integer('print_every', 10, 'Print every n iterations')
- # flags.DEFINE_integer('hidden_size', 200, 'Hidden size for neural networks')
- # flags.DEFINE_integer('n_iterations', 1000, 'number of iterations')
-
- # For bigger model:
flags.DEFINE_integer('latent_dim', 100, 'Latent dimensionality of model')
flags.DEFINE_integer('batch_size', 64, 'Minibatch size')
flags.DEFINE_integer('n_samples', 1, 'Number of samples to save')
@@ -50,12 +38,15 @@ def inference_network(x, latent_dim, hidden_size):
    mu: Mean parameters for the variational family Normal
    sigma: Standard deviation parameters for the variational family Normal
  """
-   with slim.arg_scope([slim.fully_connected], activation_fn=tf.nn.relu):
-     net = slim.flatten(x)
-     net = slim.fully_connected(net, hidden_size)
-     net = slim.fully_connected(net, hidden_size)
-     gaussian_params = slim.fully_connected(
-         net, latent_dim * 2, activation_fn=None)
+   inference_net = tfk.Sequential([
+       tfkl.Flatten(),
+       tfkl.Dense(hidden_size, activation=tf.nn.relu),
+       tfkl.Dense(hidden_size, activation=tf.nn.relu),
+       tfkl.Dense(latent_dim * 2, activation=None)
+   ])
+   gaussian_params = inference_net(x)
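+   # gaussian_params has shape [batch, 2 * latent_dim]; the two halves are
+   # split into the mean and the (pre-softplus) standard deviation below.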
  # The mean parameter is unconstrained
  mu = gaussian_params[:, :latent_dim]
  # The standard deviation must be positive. Parametrize with a softplus
@@ -73,174 +64,129 @@ def generative_network(z, hidden_size):
  Returns:
    bernoulli_logits: logits for the Bernoulli likelihood of the data
  """
-   with slim.arg_scope([slim.fully_connected], activation_fn=tf.nn.relu):
-     net = slim.fully_connected(z, hidden_size)
-     net = slim.fully_connected(net, hidden_size)
-     bernoulli_logits = slim.fully_connected(net, 784, activation_fn=None)
-     bernoulli_logits = tf.reshape(bernoulli_logits, [-1, 28, 28, 1])
-   return bernoulli_logits
+   # Keras layers do not participate in tfc.variable_scope reuse, so build the
+   # Sequential once and cache it on the function object; the three call sites
+   # in train() then share one set of weights (hidden_size is assumed constant).
+   net = getattr(generative_network, '_net', None)
+   if net is None:
+     net = tfk.Sequential([
+         tfkl.Dense(hidden_size, activation=tf.nn.relu),
+         tfkl.Dense(hidden_size, activation=tf.nn.relu),
+         tfkl.Dense(28 * 28, activation=None)
+     ])
+     generative_network._net = net
+   bernoulli_logits = net(z)
+   return tf.reshape(bernoulli_logits, [-1, 28, 28, 1])


def train():
  # Train a Variational Autoencoder on MNIST

  # Input placeholders
  with tf.name_scope('data'):
-     x = tf.placeholder(tf.float32, [None, 28, 28, 1])
-     tf.summary.image('data', x)
+     x = tfc.placeholder(tf.float32, [None, 28, 28, 1])
+     tfc.summary.image('data', x)

-   with tf.variable_scope('variational'):
+   with tfc.variable_scope('variational'):
    q_mu, q_sigma = inference_network(x=x,
                                      latent_dim=FLAGS.latent_dim,
                                      hidden_size=FLAGS.hidden_size)
    # The variational distribution is a Normal with mean and standard
    # deviation given by the inference network
-     q_z = distributions.Normal(loc=q_mu, scale=q_sigma)
-     assert q_z.reparameterization_type == distributions.FULLY_REPARAMETERIZED
+     q_z = tfp.distributions.Normal(loc=q_mu, scale=q_sigma)
+     assert q_z.reparameterization_type == tfp.distributions.FULLY_REPARAMETERIZED
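+     # FULLY_REPARAMETERIZED means q_z.sample() is differentiable with respect
+     # to the inference-network outputs, so gradients can flow through the ELBO.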

-   with tf.variable_scope('model'):
+   with tfc.variable_scope('model'):
    # The likelihood is Bernoulli-distributed with logits given by the
    # generative network
    p_x_given_z_logits = generative_network(z=q_z.sample(),
                                            hidden_size=FLAGS.hidden_size)
-     p_x_given_z = distributions.Bernoulli(logits=p_x_given_z_logits)
+     p_x_given_z = tfp.distributions.Bernoulli(logits=p_x_given_z_logits)
    posterior_predictive_samples = p_x_given_z.sample()
-     tf.summary.image('posterior_predictive',
+     tfc.summary.image('posterior_predictive',
                     tf.cast(posterior_predictive_samples, tf.float32))
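+     # Reconstructions of x: encode with q(z|x), then sample from p(x|z).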

  # Take samples from the prior
-   with tf.variable_scope('model', reuse=True):
-     p_z = distributions.Normal(loc=np.zeros(FLAGS.latent_dim, dtype=np.float32),
+   with tfc.variable_scope('model', reuse=True):
+     p_z = tfp.distributions.Normal(loc=np.zeros(FLAGS.latent_dim, dtype=np.float32),
                               scale=np.ones(FLAGS.latent_dim, dtype=np.float32))
    p_z_sample = p_z.sample(FLAGS.n_samples)
    p_x_given_z_logits = generative_network(z=p_z_sample,
                                            hidden_size=FLAGS.hidden_size)
-     prior_predictive = distributions.Bernoulli(logits=p_x_given_z_logits)
+     prior_predictive = tfp.distributions.Bernoulli(logits=p_x_given_z_logits)
    prior_predictive_samples = prior_predictive.sample()
-     tf.summary.image('prior_predictive',
+     tfc.summary.image('prior_predictive',
                     tf.cast(prior_predictive_samples, tf.float32))

  # Take samples from the prior with a placeholder
-   with tf.variable_scope('model', reuse=True):
-     z_input = tf.placeholder(tf.float32, [None, FLAGS.latent_dim])
+   with tfc.variable_scope('model', reuse=True):
+     z_input = tfc.placeholder(tf.float32, [None, FLAGS.latent_dim])
    p_x_given_z_logits = generative_network(z=z_input,
                                            hidden_size=FLAGS.hidden_size)
-     prior_predictive_inp = distributions.Bernoulli(logits=p_x_given_z_logits)
+     prior_predictive_inp = tfp.distributions.Bernoulli(logits=p_x_given_z_logits)
    prior_predictive_inp_sample = prior_predictive_inp.sample()

  # Build the evidence lower bound (ELBO) or the negative loss
-   kl = tf.reduce_sum(distributions.kl_divergence(q_z, p_z), 1)
+   kl = tf.reduce_sum(tfp.distributions.kl_divergence(q_z, p_z), 1)
  expected_log_likelihood = tf.reduce_sum(p_x_given_z.log_prob(x),
                                          [1, 2, 3])

  elbo = tf.reduce_sum(expected_log_likelihood - kl, 0)
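+   # ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z)), summed over the minibatch;
+   # maximizing it (by minimizing -elbo below) raises a lower bound on log p(x).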
-
-   optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001)
-
+   optimizer = tfc.train.RMSPropOptimizer(learning_rate=0.001)
  train_op = optimizer.minimize(-elbo)

  # Merge all the summaries
-   summary_op = tf.summary.merge_all()
+   summary_op = tfc.summary.merge_all()

-   init_op = tf.global_variables_initializer()
+   init_op = tfc.global_variables_initializer()

  # Run training
-   sess = tf.InteractiveSession()
+   sess = tfc.InteractiveSession()
  sess.run(init_op)

-   mnist = read_data_sets(FLAGS.data_dir, one_hot=True)
+   mnist_data = tfds.load(name='binarized_mnist', split='train', shuffle_files=False)
+   # take() bounds the otherwise endless repeat() stream; this assumes the
+   # n_iterations flag from the original script is still defined above.
+   dataset = (mnist_data.repeat().shuffle(buffer_size=1024)
+              .batch(FLAGS.batch_size).take(FLAGS.n_iterations))
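+   # binarized_mnist ships pre-binarized {0, 1} images, so the per-batch
+   # re-binarization from the original script is no longer needed.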

  print('Saving TensorBoard summaries and images to: %s' % FLAGS.logdir)
-   train_writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)
-
-   # Get fixed MNIST digits for plotting posterior means during training
-   np_x_fixed, np_y = mnist.test.next_batch(5000)
-   np_x_fixed = np_x_fixed.reshape(5000, 28, 28, 1)
-   np_x_fixed = (np_x_fixed > 0.5).astype(np.float32)
+   train_writer = tfc.summary.FileWriter(FLAGS.logdir, sess.graph)

  t0 = time.time()
-   for i in range(FLAGS.n_iterations):
-     # Re-binarize the data at every batch; this improves results
-     np_x, _ = mnist.train.next_batch(FLAGS.batch_size)
-     np_x = np_x.reshape(FLAGS.batch_size, 28, 28, 1)
-     np_x = (np_x > 0.5).astype(np.float32)
+   for i, batch in enumerate(tfds.as_numpy(dataset)):
+     np_x = batch['image']
    sess.run(train_op, {x: np_x})
-
-     # Print progress and save samples every so often
    if i % FLAGS.print_every == 0:
      np_elbo, summary_str = sess.run([elbo, summary_op], {x: np_x})
      train_writer.add_summary(summary_str, i)
      print('Iteration: {0:d} ELBO: {1:.3f} s/iter: {2:.3e}'.format(
          i,
          np_elbo / FLAGS.batch_size,
          (time.time() - t0) / FLAGS.print_every))
-       t0 = time.time()
-
      # Save samples
      np_posterior_samples, np_prior_samples = sess.run(
          [posterior_predictive_samples, prior_predictive_samples], {x: np_x})
      for k in range(FLAGS.n_samples):
        f_name = os.path.join(
            FLAGS.logdir, 'iter_%d_posterior_predictive_%d_data.jpg' % (i, k))
-         imwrite(f_name, np_x[k, :, :, 0])
+         # Scale the {0, 1} pixels to {0, 255} so the saved images are visible.
+         imwrite(f_name, (np_x[k, :, :, 0] * 255).astype(np.uint8))
        f_name = os.path.join(
            FLAGS.logdir, 'iter_%d_posterior_predictive_%d_sample.jpg' % (i, k))
-         imwrite(f_name, np_posterior_samples[k, :, :, 0])
+         imwrite(f_name, (np_posterior_samples[k, :, :, 0] * 255).astype(np.uint8))
        f_name = os.path.join(
            FLAGS.logdir, 'iter_%d_prior_predictive_%d.jpg' % (i, k))
-         imwrite(f_name, np_prior_samples[k, :, :, 0])
-
-       # Plot the posterior predictive space
-       if FLAGS.latent_dim == 2:
-         np_q_mu = sess.run(q_mu, {x: np_x_fixed})
-         cmap = mpl.colors.ListedColormap(sns.color_palette("husl"))
-         f, ax = plt.subplots(1, figsize=(6 * 1.1618, 6))
-         im = ax.scatter(np_q_mu[:, 0], np_q_mu[:, 1], c=np.argmax(np_y, 1), cmap=cmap,
-                         alpha=0.7)
-         ax.set_xlabel('First dimension of sampled latent variable $z_1$')
-         ax.set_ylabel('Second dimension of sampled latent variable mean $z_2$')
-         ax.set_xlim([-10., 10.])
-         ax.set_ylim([-10., 10.])
-         f.colorbar(im, ax=ax, label='Digit class')
-         plt.tight_layout()
-         plt.savefig(os.path.join(FLAGS.logdir,
-                                  'posterior_predictive_map_frame_%d.png' % i))
-         plt.close()
-
-         nx = ny = 20
-         x_values = np.linspace(-3, 3, nx)
-         y_values = np.linspace(-3, 3, ny)
-         canvas = np.empty((28 * ny, 28 * nx))
-         for ii, yi in enumerate(x_values):
-           for j, xi in enumerate(y_values):
-             np_z = np.array([[xi, yi]])
-             x_mean = sess.run(prior_predictive_inp_sample, {z_input: np_z})
-             canvas[(nx - ii - 1) * 28:(nx - ii) * 28, j *
-                    28:(j + 1) * 28] = x_mean[0].reshape(28, 28)
-         imwrite(os.path.join(FLAGS.logdir,
-                              'prior_predictive_map_frame_%d.png' % i), canvas)
-         # plt.figure(figsize=(8, 10))
-         # Xi, Yi = np.meshgrid(x_values, y_values)
-         # plt.imshow(canvas, origin="upper")
-         # plt.tight_layout()
-         # plt.savefig()
-
-   # Make the gifs
-   if FLAGS.latent_dim == 2:
-     os.system(
-         'convert -delay 15 -loop 0 {0}/posterior_predictive_map_frame*png {0}/posterior_predictive.gif'
-         .format(FLAGS.logdir))
-     os.system(
-         'convert -delay 15 -loop 0 {0}/prior_predictive_map_frame*png {0}/prior_predictive.gif'
-         .format(FLAGS.logdir))
-
-
- def main(_):
-   if tf.gfile.Exists(FLAGS.logdir):
-     tf.gfile.DeleteRecursively(FLAGS.logdir)
-   tf.gfile.MakeDirs(FLAGS.logdir)
-   train()
-
+         imwrite(f_name, (np_prior_samples[k, :, :, 0] * 255).astype(np.uint8))
+       t0 = time.time()

if __name__ == '__main__':
-   tf.app.run()
+   # tfc.app.run() parses the absl flags before invoking its main argument.
+   tfc.app.run(main=lambda _argv: train())