
Commit

parallel with convergence tested
Harsh Nilesh Pathak authored Jun 28, 2018
1 parent 7c8d059 commit b372abf
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions lib/model.py
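As the hunks below show, the per-tower discriminator train step is dropped: inside each tower the discriminator's apply_gradients call (and the leftover scope.reuse_variables() calls) are commented out or removed, and only the tower's (gradient, variable) pairs are kept in tower_grads_d; the generator tower also no longer waits on a per-tower discrim_train. Instead, after the towers are built, the discriminator gradients are averaged and applied once, and the generator's averaged gradients are now applied under tf.control_dependencies([discrim_train]), so each global generator update still runs only after the aggregated discriminator update. A sketch of the resulting op wiring follows the diff.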
@@ -447,19 +447,18 @@ def SRGAN(inputs, targets, FLAGS, devices = ['/gpu:%d'%i for i in range(8)]):
     learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, staircase=FLAGS.stair)
     incr_global_step = tf.assign(global_step, global_step + 1)
 
-    #scope.reuse_variables()
     with tf.variable_scope('dicriminator_train',reuse=tf.AUTO_REUSE):
         discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
         discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
         discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars)
-        discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars)
-        scope.reuse_variables()
+        #discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars)
+        #scope.reuse_variables()
         tower_grads_d.append(discrim_grads_and_vars)
 
-    scope.reuse_variables()
     with tf.variable_scope('generator_train'):
         # Need to wait discriminator to perform train step
-        with tf.control_dependencies([discrim_train]+ tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+        with tf.control_dependencies( tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
             gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
             gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
             gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
@@ -470,10 +469,6 @@ def SRGAN(inputs, targets, FLAGS, devices = ['/gpu:%d'%i for i in range(8)]):
     exp_averager = tf.train.ExponentialMovingAverage(decay=0.99)
     update_loss = exp_averager.apply([discrim_loss, adversarial_loss, content_loss])
 
-    #generator aggregation
-    avg_grads = average_gradients(tower_grads)
-    gen_train = gen_optimizer.apply_gradients(avg_grads)
-    all_outputs_g = tf.concat(tower_outputs, axis=0)
 
     #discriminator aggregation
     avg_grads_d = average_gradients(tower_grads_d)
@@ -482,6 +477,14 @@ def SRGAN(inputs, targets, FLAGS, devices = ['/gpu:%d'%i for i in range(8)]):
     all_outputs_real_d = tf.concat(tower_outputs_real_d, axis=0)
     all_outputs_fake_d = tf.concat(tower_outputs_fake_d, axis=0)
 
+    with tf.control_dependencies([discrim_train]):
+        #generator aggregation
+        avg_grads = average_gradients(tower_grads)
+        gen_train = gen_optimizer.apply_gradients(avg_grads)
+
+    all_outputs_g = tf.concat(tower_outputs, axis=0)
+
+
     return Network(
         discrim_real_output = all_outputs_real_d,
         discrim_fake_output = all_outputs_fake_d,
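The diff relies on an average_gradients helper defined elsewhere in lib/model.py, which is not shown here. Below is a minimal, self-contained sketch (TensorFlow 1.x API, toy single-variable stand-ins for the generator and discriminator, two towers on one device) of the op wiring the commit arrives at: per-tower (grad, var) pairs are collected, averaged, the discriminator update is applied once, and the generator update is gated on discrim_train via a control dependency. The tower_losses helper, the two-tower loop, and the hyperparameters are illustrative assumptions; the averaging helper shown is the common multi-tower idiom and may differ from the repository's actual implementation.

import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.compat.v1 in TF 2.x)

def average_gradients(tower_grads):
    # tower_grads: one list of (grad, var) pairs per tower, all over the same shared variables.
    averaged = []
    for grad_and_vars in zip(*tower_grads):           # group the same variable across towers
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        averaged.append((grad, grad_and_vars[0][1]))  # keep one reference to the shared variable
    return averaged

def tower_losses():
    # Toy stand-ins: one trainable scalar per network, shared across towers via AUTO_REUSE.
    with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
        g = tf.get_variable('w', initializer=1.0)
    with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
        d = tf.get_variable('w', initializer=1.0)
    return tf.square(g - 2.0), tf.square(d + 1.0)     # (gen_loss, discrim_loss)

tower_grads, tower_grads_d = [], []
for _ in range(2):                                    # pretend each iteration is one GPU tower
    gen_loss, discrim_loss = tower_losses()
    gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
    gen_optimizer = tf.train.AdamOptimizer(1e-4, beta1=0.9)
    discrim_optimizer = tf.train.AdamOptimizer(1e-4, beta1=0.9)
    tower_grads.append(gen_optimizer.compute_gradients(gen_loss, gen_tvars))
    tower_grads_d.append(discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars))

# Discriminator aggregation: average across towers, apply once.
discrim_train = discrim_optimizer.apply_gradients(average_gradients(tower_grads_d))

# Generator aggregation, gated so it runs only after the aggregated discriminator step.
with tf.control_dependencies([discrim_train]):
    gen_train = gen_optimizer.apply_gradients(average_gradients(tower_grads))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(gen_train)   # the control dependency forces discrim_train to run first

Running gen_train therefore performs one discriminator update on the averaged tower gradients followed by one generator update per step: the same ordering the per-tower control dependency enforced before this change, now applied once at the aggregated level.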
