""" Solution for simple logistic regression model for MNIST
with placeholder
MNIST dataset: yann.lecun.com/exdb/mnist/
Created by Chip Huyen ([email protected])
CS20: "TensorFlow for Deep Learning Research"
cs20.stanford.edu
Lecture 03
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import time

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import utils

# Define parameters for the model
learning_rate = 0.01
batch_size = 128
n_epochs = 30

# Step 1: Read in data
# use TF Learn's built-in function to load MNIST data into the folder data/mnist
mnist = input_data.read_data_sets('data/mnist', one_hot=True)
X_batch, Y_batch = mnist.train.next_batch(batch_size)    # fetch one batch to see what next_batch returns (not used below)

# Step 2: create placeholders for features and labels
# each image in the MNIST data is of shape 28*28 = 784
# therefore, each image is represented with a 1x784 tensor
# there are 10 classes for each image, corresponding to digits 0 - 9.
# each label is a one-hot vector.
X = tf.placeholder(tf.float32, [batch_size, 784], name='image')
Y = tf.placeholder(tf.int32, [batch_size, 10], name='label')
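# note: fixing the first dimension to batch_size means every fed batch must hold
# exactly batch_size examples; a shape of [None, 784] (and [None, 10] for Y), e.g.
#   X = tf.placeholder(tf.float32, [None, 784], name='image')
# would also accept batches of any size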

# Step 3: create weights and bias
# w is initialized to random values with mean of 0, stddev of 0.01
# b is initialized to 0
# shape of w depends on the dimension of X and Y so that Y = tf.matmul(X, w)
# shape of b depends on Y
w = tf.get_variable(name='weights', shape=(784, 10),
                    initializer=tf.random_normal_initializer(mean=0.0, stddev=0.01))
b = tf.get_variable(name='bias', shape=(1, 10), initializer=tf.zeros_initializer())

# Step 4: build model
# the model that returns the logits.
# these logits will later be passed through a softmax layer
logits = tf.matmul(X, w) + b
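# logits has shape [batch_size, 10]: one unnormalized score per class for each image,
# since X is [batch_size, 784], w is [784, 10], and b broadcasts over the batch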

# Step 5: define loss function
# use cross entropy of softmax of logits as the loss function
entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y, name='loss')
loss = tf.reduce_mean(entropy)    # computes the mean over all the examples in the batch
# equivalent manual version (less numerically stable than the fused op above):
# loss = tf.reduce_mean(-tf.reduce_sum(tf.cast(Y, tf.float32) * tf.log(tf.nn.softmax(logits)), axis=[1]))
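# for one example with one-hot label Y and predicted probabilities p = softmax(logits):
#   p_i = exp(logits_i) / sum_j exp(logits_j)
#   cross_entropy = -sum_i Y_i * log(p_i)
# softmax_cross_entropy_with_logits fuses both steps for numerical stability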

# Step 6: define training op
# using the Adam optimizer with learning rate of 0.01 to minimize loss
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)
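# a plain gradient-descent alternative, if Adam is not wanted:
#   optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)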

# Step 7: calculate accuracy with test set
preds = tf.nn.softmax(logits)
correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))    # count of correct predictions in the batch
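# note: tf.argmax(preds, 1) equals tf.argmax(logits, 1) because softmax is monotonic,
# so the explicit softmax here only matters if the probabilities themselves are needed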

writer = tf.summary.FileWriter('./graphs/logreg_placeholder', tf.get_default_graph())
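# visualize the graph with: tensorboard --logdir ./graphs/logreg_placeholder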
with tf.Session() as sess:
    start_time = time.time()
    sess.run(tf.global_variables_initializer())
    n_batches = int(mnist.train.num_examples/batch_size)

    # train the model n_epochs times
    for i in range(n_epochs):
        total_loss = 0

        for j in range(n_batches):
            X_batch, Y_batch = mnist.train.next_batch(batch_size)
            _, loss_batch = sess.run([optimizer, loss], {X: X_batch, Y: Y_batch})
            total_loss += loss_batch
        print('Average loss epoch {0}: {1}'.format(i, total_loss/n_batches))
    print('Total time: {0} seconds'.format(time.time() - start_time))

    # test the model
    n_batches = int(mnist.test.num_examples/batch_size)
    total_correct_preds = 0

    for i in range(n_batches):
        X_batch, Y_batch = mnist.test.next_batch(batch_size)
        accuracy_batch = sess.run(accuracy, {X: X_batch, Y: Y_batch})
        total_correct_preds += accuracy_batch

    print('Accuracy {0}'.format(total_correct_preds/(n_batches * batch_size)))    # only n_batches * batch_size test images are evaluated
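    # with these settings, softmax regression typically reaches roughly 91-92% test accuracy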

writer.close()