This repository was archived by the owner on Oct 19, 2019. It is now read-only.
4 changes: 2 additions & 2 deletions config.py
@@ -7,12 +7,12 @@
 class Config:
     def __init__(self):
         root = self.Scope('')
-        for k, v in FLAGS.__dict__['__flags'].iteritems():
+        for k, v in FLAGS.__dict__['__flags'].items():
             root[k] = v
         self.stack = [ root ]

     def iteritems(self):
-        return self.to_dict().iteritems()
+        return self.to_dict().items()

     def to_dict(self):
         self._pop_stale()
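The config.py change is the standard Python 2-to-3 migration of dict iteration: dict.iteritems() no longer exists in Python 3, and dict.items() returns a view that can be iterated directly. A minimal sketch of the idiom, using an invented flag dictionary rather than the repo's real FLAGS object:

# Minimal sketch of the Python 3 idiom adopted above; the flag values are invented.
flags = {'batch_size': 128, 'learning_rate': 0.1}

# Python 3: items() returns a lightweight view, so no separate iteritems() is needed.
for k, v in flags.items():
    print(k, v)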
8 changes: 4 additions & 4 deletions resnet.py
@@ -101,7 +101,7 @@ def inference_small(x,
     c['fc_units_out'] = num_classes
     c['num_blocks'] = num_blocks
     c['num_classes'] = num_classes
-    inference_small_config(x, c)
+    return inference_small_config(x, c)

 def inference_small_config(x, c):
     c['bottleneck'] = False
@@ -145,13 +145,13 @@ def _imagenet_preprocess(rgb):


 def loss(logits, labels):
-    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
+    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
     cross_entropy_mean = tf.reduce_mean(cross_entropy)

     regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

     loss_ = tf.add_n([cross_entropy_mean] + regularization_losses)
-    tf.scalar_summary('loss', loss_)
+    tf.summary.scalar('loss', loss_)

     return loss_

@@ -300,7 +300,7 @@ def _get_variable(name,
         regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
     else:
         regularizer = None
-    collections = [tf.GraphKeys.VARIABLES, RESNET_VARIABLES]
+    collections = [tf.GraphKeys.GLOBAL_VARIABLES, RESNET_VARIABLES]
     return tf.get_variable(name,
                            shape=shape,
                            initializer=initializer,
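Two of the resnet.py edits are behaviour fixes rather than renames: inference_small now returns the result of inference_small_config instead of dropping it, and the cross-entropy call switches to keyword arguments because TensorFlow 1.0 changed the signature of tf.nn.sparse_softmax_cross_entropy_with_logits so that labels and logits must be passed by keyword. The remaining edits are straight renames (tf.scalar_summary to tf.summary.scalar, tf.GraphKeys.VARIABLES to tf.GraphKeys.GLOBAL_VARIABLES). A small self-contained check of the keyword-argument form under TF 1.x graph mode; the logits and labels here are invented toy values, not the repo's tensors:

import tensorflow as tf

# Toy values: a batch of 2 examples over 3 classes.
logits = tf.constant([[2.0, 0.5, 0.1],
                      [0.2, 1.5, 0.3]])
labels = tf.constant([0, 1])  # integer class ids, one per example

# Keyword arguments make the call unambiguous under the TF 1.x signature.
xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
loss_ = tf.reduce_mean(xent)

with tf.Session() as sess:
    print(sess.run(loss_))  # small positive scalar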
20 changes: 10 additions & 10 deletions resnet_train.py
@@ -39,15 +39,15 @@ def train(is_training, logits, images, labels):
     # loss_avg
     ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
     tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss_]))
-    tf.scalar_summary('loss_avg', ema.average(loss_))
+    tf.summary.scalar('loss_avg', ema.average(loss_))

     # validation stats
     ema = tf.train.ExponentialMovingAverage(0.9, val_step)
     val_op = tf.group(val_step.assign_add(1), ema.apply([top1_error]))
     top1_error_avg = ema.average(top1_error)
-    tf.scalar_summary('val_top1_error_avg', top1_error_avg)
+    tf.summary.scalar('val_top1_error_avg', top1_error_avg)

-    tf.scalar_summary('learning_rate', FLAGS.learning_rate)
+    tf.summary.scalar('learning_rate', FLAGS.learning_rate)

     opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
     grads = opt.compute_gradients(loss_)
@@ -67,27 +67,27 @@ def train(is_training, logits, images, labels):
     batchnorm_updates_op = tf.group(*batchnorm_updates)
     train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

-    saver = tf.train.Saver(tf.all_variables())
+    saver = tf.train.Saver(tf.global_variables())

-    summary_op = tf.merge_all_summaries()
+    summary_op = tf.summary.merge_all()

-    init = tf.initialize_all_variables()
+    init = tf.global_variables_initializer()

     sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
     sess.run(init)
     tf.train.start_queue_runners(sess=sess)

-    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
+    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

     if FLAGS.resume:
         latest = tf.train.latest_checkpoint(FLAGS.train_dir)
         if not latest:
-            print "No checkpoint to continue from in", FLAGS.train_dir
+            print("No checkpoint to continue from in", FLAGS.train_dir)
             sys.exit(1)
-        print "resume", latest
+        print("resume", latest)
         saver.restore(sess, latest)

-    for x in xrange(FLAGS.max_steps + 1):
+    for x in range(FLAGS.max_steps + 1):
         start_time = time.time()

         step = sess.run(global_step)
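The resnet_train.py edits are the matching TF 1.x renames on the training side (tf.summary.scalar, tf.summary.merge_all, tf.summary.FileWriter, tf.global_variables, tf.global_variables_initializer) plus the Python 3 print() and range() forms. A minimal, self-contained sketch of how those renamed pieces fit together in TF 1.x graph mode; the single variable, the log directory, and the one recorded step are placeholders, not the repo's real training loop:

import tensorflow as tf

# Placeholder graph: one scalar variable standing in for the training loss.
loss_ = tf.Variable(0.0, name='loss')
tf.summary.scalar('loss_avg', loss_)

summary_op = tf.summary.merge_all()             # replaces tf.merge_all_summaries()
init = tf.global_variables_initializer()        # replaces tf.initialize_all_variables()
saver = tf.train.Saver(tf.global_variables())   # replaces tf.all_variables()

with tf.Session() as sess:
    sess.run(init)
    # replaces tf.train.SummaryWriter; '/tmp/train_dir' is a made-up log directory
    writer = tf.summary.FileWriter('/tmp/train_dir', sess.graph)
    writer.add_summary(sess.run(summary_op), global_step=0)
    writer.close()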
14 changes: 7 additions & 7 deletions train_cifar.py
@@ -192,7 +192,7 @@ def distorted_inputs(data_dir, batch_size):
       distorted_image, lower=0.2, upper=1.8)

   # Subtract off the mean and divide by the variance of the pixels.
-  float_image = tf.image.per_image_whitening(distorted_image)
+  float_image = tf.image.per_image_standardization(distorted_image)

   # Ensure that the random shuffling has good mixing properties.
   min_fraction_of_examples_in_queue = 0.4
@@ -250,7 +250,7 @@ def inputs(eval_data, data_dir, batch_size):
                                                          width, height)

   # Subtract off the mean and divide by the variance of the pixels.
-  float_image = tf.image.per_image_whitening(resized_image)
+  float_image = tf.image.per_image_standardization(resized_image)

   # Ensure that the random shuffling has good mixing properties.
   min_fraction_of_examples_in_queue = 0.4
@@ -299,11 +299,11 @@ def main(argv=None):  # pylint: disable=unused-argument
       lambda: (images_train, labels_train),
       lambda: (images_val, labels_val))

-    logits = inference_small(images,
-                             num_classes=10,
-                             is_training=is_training,
-                             use_bias=(not FLAGS.use_bn),
-                             num_blocks=3)
+  logits = inference_small(images,
+                           num_classes=10,
+                           is_training=is_training,
+                           use_bias=(not FLAGS.use_bn),
+                           num_blocks=3)
   train(is_training, logits, images, labels)
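In train_cifar.py the only API change is the rename of tf.image.per_image_whitening to tf.image.per_image_standardization; the operation itself (subtract the per-image mean, divide by the adjusted standard deviation) is unchanged, and the inference_small block appears to change only in indentation. A quick sketch of the renamed op on an invented CIFAR-sized image:

import tensorflow as tf

# Invented 32x32x3 image standing in for a decoded CIFAR example.
image = tf.random_uniform([32, 32, 3], maxval=255.0)

# Same normalization as the old per_image_whitening: per-image zero mean,
# unit adjusted variance.
float_image = tf.image.per_image_standardization(image)

with tf.Session() as sess:
    out = sess.run(float_image)
    print(out.mean(), out.std())  # approximately 0 and 1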