
Commit 55a34ae

fixing issues tensorflow#23 and tensorflow#25
1 parent 38079f0 commit 55a34ae

2 files changed, +6 -12 lines changed


autoencoder/VariationalAutoencoderRunner.py

Lines changed: 4 additions & 5 deletions
@@ -9,8 +9,8 @@
 mnist = input_data.read_data_sets('MNIST_data', one_hot = True)
 
 
-def standard_scale(X_train, X_test):
-    preprocessor = prep.StandardScaler().fit(X_train)
+def minmax_scale(X_train, X_test):
+    preprocessor = prep.MinMaxScaler(feature_range=(0, 1)).fit(X_train)
     X_train = preprocessor.transform(X_train)
     X_test = preprocessor.transform(X_test)
     return X_train, X_test
@@ -21,7 +21,7 @@ def get_random_block_from_data(data, batch_size):
     return data[start_index:(start_index + batch_size)]
 
 
-X_train, X_test = standard_scale(mnist.train.images, mnist.test.images)
+X_train, X_test = minmax_scale(mnist.train.images, mnist.test.images)
 
 n_samples = int(mnist.train.num_examples)
 training_epochs = 20
@@ -30,8 +30,7 @@ def get_random_block_from_data(data, batch_size):
 
 autoencoder = VariationalAutoencoder(n_input = 784,
                                      n_hidden = 200,
-                                     optimizer = tf.train.AdamOptimizer(learning_rate = 0.001),
-                                     gaussian_sample_size = 128)
+                                     optimizer = tf.train.AdamOptimizer(learning_rate = 0.001))
 
 for epoch in range(training_epochs):
     avg_cost = 0.
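
The swap above replaces zero-mean/unit-variance standardization with a min-max rescale into [0, 1]. For reference, a minimal sketch of the two helpers, assuming the prep alias in this script is sklearn.preprocessing (as imported elsewhere in the autoencoder examples); StandardScaler can push MNIST pixel values well outside [0, 1], while MinMaxScaler keeps them inside that range.

import sklearn.preprocessing as prep

def standard_scale(X_train, X_test):
    # old behaviour: zero mean / unit variance per feature, fitted on the training set
    preprocessor = prep.StandardScaler().fit(X_train)
    return preprocessor.transform(X_train), preprocessor.transform(X_test)

def minmax_scale(X_train, X_test):
    # new behaviour: rescale each feature into [0, 1] using the training-set range
    preprocessor = prep.MinMaxScaler(feature_range=(0, 1)).fit(X_train)
    return preprocessor.transform(X_train), preprocessor.transform(X_test)

The runner then calls minmax_scale(mnist.train.images, mnist.test.images) exactly as shown in the hunk above.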

autoencoder/autoencoder_models/VariationalAutoencoder.py

Lines changed: 2 additions & 7 deletions
@@ -4,11 +4,9 @@
 
 class VariationalAutoencoder(object):
 
-    def __init__(self, n_input, n_hidden, optimizer = tf.train.AdamOptimizer(),
-                 gaussian_sample_size = 128):
+    def __init__(self, n_input, n_hidden, optimizer = tf.train.AdamOptimizer()):
         self.n_input = n_input
         self.n_hidden = n_hidden
-        self.gaussian_sample_size = gaussian_sample_size
 
         network_weights = self._initialize_weights()
         self.weights = network_weights
@@ -18,14 +16,12 @@ def __init__(self, n_input, n_hidden, optimizer = tf.train.AdamOptimizer(),
         self.z_mean = tf.add(tf.matmul(self.x, self.weights['w1']), self.weights['b1'])
         self.z_log_sigma_sq = tf.add(tf.matmul(self.x, self.weights['log_sigma_w1']), self.weights['log_sigma_b1'])
 
-
         # sample from gaussian distribution
-        eps = tf.random_normal((self.gaussian_sample_size, n_hidden), 0, 1, dtype = tf.float32)
+        eps = tf.random_normal(tf.pack([tf.shape(self.x)[0], self.n_hidden]), 0, 1, dtype = tf.float32)
         self.z = tf.add(self.z_mean, tf.mul(tf.sqrt(tf.exp(self.z_log_sigma_sq)), eps))
 
         self.reconstruction = tf.add(tf.matmul(self.z, self.weights['w2']), self.weights['b2'])
 
-
         # cost
         reconstr_loss = 0.5 * tf.reduce_sum(tf.pow(tf.sub(self.reconstruction, self.x), 2.0))
         latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq
@@ -38,7 +34,6 @@ def __init__(self, n_input, n_hidden, optimizer = tf.train.AdamOptimizer(),
         self.sess = tf.Session()
         self.sess.run(init)
 
-
     def _initialize_weights(self):
         all_weights = dict()
         all_weights['w1'] = tf.Variable(autoencoder.Utils.xavier_init(self.n_input, self.n_hidden))
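
The model-side change drops the fixed gaussian_sample_size and instead shapes the Gaussian noise by the runtime batch size, tf.shape(self.x)[0], so the reparameterization step works for whatever batch is fed in. Below is a standalone sketch of that sampling pattern, written against the same TF 0.x-era API as this repository (tf.pack and tf.mul were renamed tf.stack and tf.multiply in later TensorFlow releases); the zero-initialized variables are stand-ins for the xavier-initialized weights in the real model.

import tensorflow as tf

n_input, n_hidden = 784, 200
x = tf.placeholder(tf.float32, [None, n_input])   # batch size unknown until runtime

# stand-in encoder parameters (the real model uses autoencoder.Utils.xavier_init)
w1 = tf.Variable(tf.zeros([n_input, n_hidden]))
b1 = tf.Variable(tf.zeros([n_hidden]))
log_sigma_w1 = tf.Variable(tf.zeros([n_input, n_hidden]))
log_sigma_b1 = tf.Variable(tf.zeros([n_hidden]))

z_mean = tf.add(tf.matmul(x, w1), b1)
z_log_sigma_sq = tf.add(tf.matmul(x, log_sigma_w1), log_sigma_b1)

# draw one noise vector per example in the current batch, whatever its size
eps = tf.random_normal(tf.pack([tf.shape(x)[0], n_hidden]), 0, 1, dtype=tf.float32)
z = tf.add(z_mean, tf.mul(tf.sqrt(tf.exp(z_log_sigma_sq)), eps))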
