Skip to content

Commit b43325e

Browse files
author
Jaan Altosaar
committed
upgrade to tf 1.1.4; slim <- keras; tf.distributions <- tf.probability; tf <- tf.compat.v1
1 parent 96337fa commit b43325e

File tree

2 files changed

+122
-123
lines changed

2 files changed

+122
-123
lines changed

environment.yml

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
name: dev
2+
channels:
3+
- defaults
4+
dependencies:
5+
- blas=1.0=mkl
6+
- ca-certificates=2019.5.15=1
7+
- certifi=2019.6.16=py37_1
8+
- freetype=2.9.1=hb4e5f40_0
9+
- imageio=2.5.0=py37_0
10+
- intel-openmp=2019.4=233
11+
- jpeg=9b=he5867d9_2
12+
- libcxx=4.0.1=hcfea43d_1
13+
- libcxxabi=4.0.1=hcfea43d_1
14+
- libedit=3.1.20181209=hb402a30_0
15+
- libffi=3.2.1=h475c297_4
16+
- libgfortran=3.0.1=h93005f0_2
17+
- libpng=1.6.37=ha441bb4_0
18+
- libtiff=4.0.10=hcb84e12_2
19+
- mkl=2019.4=233
20+
- mkl-service=2.3.0=py37hfbe908c_0
21+
- mkl_fft=1.0.14=py37h5e564d8_0
22+
- mkl_random=1.0.2=py37h27c97d8_0
23+
- ncurses=6.1=h0a44026_1
24+
- numpy=1.16.5=py37hacdab7b_0
25+
- numpy-base=1.16.5=py37h6575580_0
26+
- olefile=0.46=py37_0
27+
- openssl=1.1.1d=h1de35cc_1
28+
- pillow=6.1.0=py37hb68e598_0
29+
- python=3.7.4=h359304d_1
30+
- readline=7.0=h1de35cc_5
31+
- setuptools=41.0.1=py37_0
32+
- six=1.12.0=py37_0
33+
- sqlite=3.29.0=ha441bb4_0
34+
- tk=8.6.8=ha441bb4_0
35+
- wheel=0.33.4=py37_0
36+
- xz=5.2.4=h1de35cc_4
37+
- zlib=1.2.11=h1de35cc_3
38+
- zstd=1.3.7=h5bba6e5_0
39+
- pip:
40+
- absl-py==0.8.0
41+
- astor==0.8.0
42+
- attrs==19.1.0
43+
- chardet==3.0.4
44+
- cloudpickle==1.2.2
45+
- decorator==4.4.0
46+
- dill==0.3.0
47+
- future==0.17.1
48+
- gast==0.3.2
49+
- google-pasta==0.1.7
50+
- googleapis-common-protos==1.6.0
51+
- grpcio==1.23.0
52+
- h5py==2.10.0
53+
- idna==2.8
54+
- keras-applications==1.0.8
55+
- keras-preprocessing==1.1.0
56+
- markdown==3.1.1
57+
- pip==19.2.3
58+
- promise==2.2.1
59+
- protobuf==3.9.1
60+
- psutil==5.6.3
61+
- requests==2.22.0
62+
- tensorboard==1.14.0
63+
- tensorflow==1.14.0
64+
- tensorflow-datasets==1.2.0
65+
- tensorflow-estimator==1.14.0
66+
- tensorflow-metadata==0.14.0
67+
- tensorflow-probability==0.7.0
68+
- termcolor==1.1.0
69+
- tqdm==4.36.0
70+
- urllib3==1.25.3
71+
- werkzeug==0.15.6
72+
- wrapt==1.11.2
73+
prefix: /usr/local/anaconda3/envs/dev
74+
+48-123
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,20 @@
11
import itertools
2-
import matplotlib as mpl
32
import numpy as np
43
import os
54
import tensorflow as tf
5+
import tensorflow.keras as tfk
66
import tensorflow.contrib.slim as slim
77
import time
8-
import seaborn as sns
9-
10-
from matplotlib import pyplot as plt
8+
import tensorflow_datasets as tfds
9+
import tensorflow_probability as tfp
1110
from imageio import imwrite
1211
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
13-
14-
sns.set_style('whitegrid')
15-
16-
distributions = tf.distributions
12+
tfkl = tfk.layers
13+
tfc = tf.compat.v1
1714

1815
flags = tf.app.flags
1916
flags.DEFINE_string('data_dir', '/tmp/dat/', 'Directory for data')
2017
flags.DEFINE_string('logdir', '/tmp/log/', 'Directory for logs')
21-
22-
# For making plots:
23-
# flags.DEFINE_integer('latent_dim', 2, 'Latent dimensionality of model')
24-
# flags.DEFINE_integer('batch_size', 64, 'Minibatch size')
25-
# flags.DEFINE_integer('n_samples', 10, 'Number of samples to save')
26-
# flags.DEFINE_integer('print_every', 10, 'Print every n iterations')
27-
# flags.DEFINE_integer('hidden_size', 200, 'Hidden size for neural networks')
28-
# flags.DEFINE_integer('n_iterations', 1000, 'number of iterations')
29-
30-
# For bigger model:
3118
flags.DEFINE_integer('latent_dim', 100, 'Latent dimensionality of model')
3219
flags.DEFINE_integer('batch_size', 64, 'Minibatch size')
3320
flags.DEFINE_integer('n_samples', 1, 'Number of samples to save')
@@ -50,12 +37,13 @@ def inference_network(x, latent_dim, hidden_size):
5037
mu: Mean parameters for the variational family Normal
5138
sigma: Standard deviation parameters for the variational family Normal
5239
"""
53-
with slim.arg_scope([slim.fully_connected], activation_fn=tf.nn.relu):
54-
net = slim.flatten(x)
55-
net = slim.fully_connected(net, hidden_size)
56-
net = slim.fully_connected(net, hidden_size)
57-
gaussian_params = slim.fully_connected(
58-
net, latent_dim * 2, activation_fn=None)
40+
inference_net = tfk.Sequential([
41+
tfkl.Flatten(),
42+
tfkl.Dense(hidden_size, activation=tf.nn.relu),
43+
tfkl.Dense(hidden_size, activation=tf.nn.relu),
44+
tfkl.Dense(latent_dim * 2, activation=None)
45+
])
46+
gaussian_params = inference_net(x)
5947
# The mean parameter is unconstrained
6048
mu = gaussian_params[:, :latent_dim]
6149
# The standard deviation must be positive. Parametrize with a softplus
@@ -73,174 +61,111 @@ def generative_network(z, hidden_size):
7361
Returns:
7462
bernoulli_logits: logits for the Bernoulli likelihood of the data
7563
"""
76-
with slim.arg_scope([slim.fully_connected], activation_fn=tf.nn.relu):
77-
net = slim.fully_connected(z, hidden_size)
78-
net = slim.fully_connected(net, hidden_size)
79-
bernoulli_logits = slim.fully_connected(net, 784, activation_fn=None)
80-
bernoulli_logits = tf.reshape(bernoulli_logits, [-1, 28, 28, 1])
81-
return bernoulli_logits
64+
generative_net = tfk.Sequential([
65+
tfkl.Dense(hidden_size, activation=tf.nn.relu),
66+
tfkl.Dense(hidden_size, activation=tf.nn.relu),
67+
tfkl.Dense(28 * 28, activation=None)
68+
])
69+
bernoulli_logits = generative_net(z)
70+
return tf.reshape(bernoulli_logits, [-1, 28, 28, 1])
8271

8372

8473
def train():
8574
# Train a Variational Autoencoder on MNIST
8675

8776
# Input placeholders
8877
with tf.name_scope('data'):
89-
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
90-
tf.summary.image('data', x)
78+
x = tfc.placeholder(tf.float32, [None, 28, 28, 1])
79+
tfc.summary.image('data', x)
9180

92-
with tf.variable_scope('variational'):
81+
with tfc.variable_scope('variational'):
9382
q_mu, q_sigma = inference_network(x=x,
9483
latent_dim=FLAGS.latent_dim,
9584
hidden_size=FLAGS.hidden_size)
9685
# The variational distribution is a Normal with mean and standard
9786
# deviation given by the inference network
98-
q_z = distributions.Normal(loc=q_mu, scale=q_sigma)
99-
assert q_z.reparameterization_type == distributions.FULLY_REPARAMETERIZED
87+
q_z = tfp.distributions.Normal(loc=q_mu, scale=q_sigma)
88+
assert q_z.reparameterization_type == tfp.distributions.FULLY_REPARAMETERIZED
10089

101-
with tf.variable_scope('model'):
90+
with tfc.variable_scope('model'):
10291
# The likelihood is Bernoulli-distributed with logits given by the
10392
# generative network
10493
p_x_given_z_logits = generative_network(z=q_z.sample(),
10594
hidden_size=FLAGS.hidden_size)
106-
p_x_given_z = distributions.Bernoulli(logits=p_x_given_z_logits)
95+
p_x_given_z = tfp.distributions.Bernoulli(logits=p_x_given_z_logits)
10796
posterior_predictive_samples = p_x_given_z.sample()
108-
tf.summary.image('posterior_predictive',
97+
tfc.summary.image('posterior_predictive',
10998
tf.cast(posterior_predictive_samples, tf.float32))
11099

111100
# Take samples from the prior
112-
with tf.variable_scope('model', reuse=True):
113-
p_z = distributions.Normal(loc=np.zeros(FLAGS.latent_dim, dtype=np.float32),
101+
with tfc.variable_scope('model', reuse=True):
102+
p_z = tfp.distributions.Normal(loc=np.zeros(FLAGS.latent_dim, dtype=np.float32),
114103
scale=np.ones(FLAGS.latent_dim, dtype=np.float32))
115104
p_z_sample = p_z.sample(FLAGS.n_samples)
116105
p_x_given_z_logits = generative_network(z=p_z_sample,
117106
hidden_size=FLAGS.hidden_size)
118-
prior_predictive = distributions.Bernoulli(logits=p_x_given_z_logits)
107+
prior_predictive = tfp.distributions.Bernoulli(logits=p_x_given_z_logits)
119108
prior_predictive_samples = prior_predictive.sample()
120-
tf.summary.image('prior_predictive',
109+
tfc.summary.image('prior_predictive',
121110
tf.cast(prior_predictive_samples, tf.float32))
122111

123112
# Take samples from the prior with a placeholder
124-
with tf.variable_scope('model', reuse=True):
113+
with tfc.variable_scope('model', reuse=True):
125114
z_input = tf.placeholder(tf.float32, [None, FLAGS.latent_dim])
126115
p_x_given_z_logits = generative_network(z=z_input,
127116
hidden_size=FLAGS.hidden_size)
128-
prior_predictive_inp = distributions.Bernoulli(logits=p_x_given_z_logits)
117+
prior_predictive_inp = tfp.distributions.Bernoulli(logits=p_x_given_z_logits)
129118
prior_predictive_inp_sample = prior_predictive_inp.sample()
130119

131120
# Build the evidence lower bound (ELBO) or the negative loss
132-
kl = tf.reduce_sum(distributions.kl_divergence(q_z, p_z), 1)
121+
kl = tf.reduce_sum(tfp.distributions.kl_divergence(q_z, p_z), 1)
133122
expected_log_likelihood = tf.reduce_sum(p_x_given_z.log_prob(x),
134123
[1, 2, 3])
135124

136125
elbo = tf.reduce_sum(expected_log_likelihood - kl, 0)
137-
138-
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001)
139-
126+
optimizer = tfc.train.RMSPropOptimizer(learning_rate=0.001)
140127
train_op = optimizer.minimize(-elbo)
141128

142129
# Merge all the summaries
143-
summary_op = tf.summary.merge_all()
130+
summary_op = tfc.summary.merge_all()
144131

145-
init_op = tf.global_variables_initializer()
132+
init_op = tfc.global_variables_initializer()
146133

147134
# Run training
148-
sess = tf.InteractiveSession()
135+
sess = tfc.InteractiveSession()
149136
sess.run(init_op)
150137

151-
mnist = read_data_sets(FLAGS.data_dir, one_hot=True)
138+
mnist_data = tfds.load(name='binarized_mnist', split='train', shuffle_files=False)
139+
dataset = mnist_data.repeat().shuffle(buffer_size=1024).batch(FLAGS.batch_size)
152140

153141
print('Saving TensorBoard summaries and images to: %s' % FLAGS.logdir)
154-
train_writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)
155-
156-
# Get fixed MNIST digits for plotting posterior means during training
157-
np_x_fixed, np_y = mnist.test.next_batch(5000)
158-
np_x_fixed = np_x_fixed.reshape(5000, 28, 28, 1)
159-
np_x_fixed = (np_x_fixed > 0.5).astype(np.float32)
142+
train_writer = tfc.summary.FileWriter(FLAGS.logdir, sess.graph)
160143

161144
t0 = time.time()
162-
for i in range(FLAGS.n_iterations):
163-
# Re-binarize the data at every batch; this improves results
164-
np_x, _ = mnist.train.next_batch(FLAGS.batch_size)
165-
np_x = np_x.reshape(FLAGS.batch_size, 28, 28, 1)
166-
np_x = (np_x > 0.5).astype(np.float32)
145+
for i, batch in enumerate(tfds.as_numpy(dataset)):
146+
np_x = batch['image']
167147
sess.run(train_op, {x: np_x})
168-
169-
# Print progress and save samples every so often
170148
if i % FLAGS.print_every == 0:
171149
np_elbo, summary_str = sess.run([elbo, summary_op], {x: np_x})
172150
train_writer.add_summary(summary_str, i)
173151
print('Iteration: {0:d} ELBO: {1:.3f} s/iter: {2:.3e}'.format(
174152
i,
175153
np_elbo / FLAGS.batch_size,
176154
(time.time() - t0) / FLAGS.print_every))
177-
t0 = time.time()
178-
179155
# Save samples
180156
np_posterior_samples, np_prior_samples = sess.run(
181157
[posterior_predictive_samples, prior_predictive_samples], {x: np_x})
182158
for k in range(FLAGS.n_samples):
183159
f_name = os.path.join(
184160
FLAGS.logdir, 'iter_%d_posterior_predictive_%d_data.jpg' % (i, k))
185-
imwrite(f_name, np_x[k, :, :, 0])
161+
imwrite(f_name, np_x[k, :, :, 0].astype(np.uint8))
186162
f_name = os.path.join(
187163
FLAGS.logdir, 'iter_%d_posterior_predictive_%d_sample.jpg' % (i, k))
188-
imwrite(f_name, np_posterior_samples[k, :, :, 0])
164+
imwrite(f_name, np_posterior_samples[k, :, :, 0].astype(np.uint8))
189165
f_name = os.path.join(
190166
FLAGS.logdir, 'iter_%d_prior_predictive_%d.jpg' % (i, k))
191-
imwrite(f_name, np_prior_samples[k, :, :, 0])
192-
193-
# Plot the posterior predictive space
194-
if FLAGS.latent_dim == 2:
195-
np_q_mu = sess.run(q_mu, {x: np_x_fixed})
196-
cmap = mpl.colors.ListedColormap(sns.color_palette("husl"))
197-
f, ax = plt.subplots(1, figsize=(6 * 1.1618, 6))
198-
im = ax.scatter(np_q_mu[:, 0], np_q_mu[:, 1], c=np.argmax(np_y, 1), cmap=cmap,
199-
alpha=0.7)
200-
ax.set_xlabel('First dimension of sampled latent variable $z_1$')
201-
ax.set_ylabel('Second dimension of sampled latent variable mean $z_2$')
202-
ax.set_xlim([-10., 10.])
203-
ax.set_ylim([-10., 10.])
204-
f.colorbar(im, ax=ax, label='Digit class')
205-
plt.tight_layout()
206-
plt.savefig(os.path.join(FLAGS.logdir,
207-
'posterior_predictive_map_frame_%d.png' % i))
208-
plt.close()
209-
210-
nx = ny = 20
211-
x_values = np.linspace(-3, 3, nx)
212-
y_values = np.linspace(-3, 3, ny)
213-
canvas = np.empty((28 * ny, 28 * nx))
214-
for ii, yi in enumerate(x_values):
215-
for j, xi in enumerate(y_values):
216-
np_z = np.array([[xi, yi]])
217-
x_mean = sess.run(prior_predictive_inp_sample, {z_input: np_z})
218-
canvas[(nx - ii - 1) * 28:(nx - ii) * 28, j *
219-
28:(j + 1) * 28] = x_mean[0].reshape(28, 28)
220-
imwrite(os.path.join(FLAGS.logdir,
221-
'prior_predictive_map_frame_%d.png' % i), canvas)
222-
# plt.figure(figsize=(8, 10))
223-
# Xi, Yi = np.meshgrid(x_values, y_values)
224-
# plt.imshow(canvas, origin="upper")
225-
# plt.tight_layout()
226-
# plt.savefig()
227-
228-
# Make the gifs
229-
if FLAGS.latent_dim == 2:
230-
os.system(
231-
'convert -delay 15 -loop 0 {0}/posterior_predictive_map_frame*png {0}/posterior_predictive.gif'
232-
.format(FLAGS.logdir))
233-
os.system(
234-
'convert -delay 15 -loop 0 {0}/prior_predictive_map_frame*png {0}/prior_predictive.gif'
235-
.format(FLAGS.logdir))
236-
237-
238-
def main(_):
239-
if tf.gfile.Exists(FLAGS.logdir):
240-
tf.gfile.DeleteRecursively(FLAGS.logdir)
241-
tf.gfile.MakeDirs(FLAGS.logdir)
242-
train()
243-
167+
imwrite(f_name, np_prior_samples[k, :, :, 0].astype(np.uint8))
168+
t0 = time.time()
244169

245170
if __name__ == '__main__':
246-
tf.app.run()
171+
train()

0 commit comments

Comments
 (0)