Commit: attention layer added
harsh306 authored Jun 29, 2018
1 parent bdbe76e commit a947f8d
Showing 1 changed file with 91 additions and 138 deletions.
lib/model.py: 229 changes (91 additions, 138 deletions)
@@ -332,9 +332,9 @@ def discriminator_block(inputs, output_channel, kernel_size, stride, scope):
 def VGG19_slim(input, type, reuse, scope):
     # Define the feature to extract according to the type of perceptual
     if type == 'VGG54':
-        target_layer = 'vgg_19/conv5/conv5_4'
+        target_layer = scope + 'vgg_19/conv5/conv5_4'
     elif type == 'VGG22':
-        target_layer = 'vgg_19/conv2/conv2_2'
+        target_layer = scope + 'vgg_19/conv2/conv2_2'
     else:
         raise NotImplementedError('Unknown perceptual type')
     _, output = vgg_19(input, is_training=False, reuse=reuse)
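
For context (not part of this diff): the new `scope` prefix is needed because the slim `vgg_19` returns its end points keyed by full op name, and inside a `tf.name_scope` those keys carry the scope as a prefix. A minimal sketch, assuming `vgg_19` lives in `lib/vgg19.py` as in the upstream repo and using a hypothetical placeholder input:

    import tensorflow as tf
    from lib.vgg19 import vgg_19  # assumed import path for the slim VGG-19 used here

    images = tf.placeholder(tf.float32, [None, 96, 96, 3])  # hypothetical input
    with tf.name_scope('vgg19_1') as scope:  # scope == 'vgg19_1/'
        _, end_points = vgg_19(images, is_training=False, reuse=False)
        # end_points keys include the enclosing name scope, e.g.
        # 'vgg19_1/vgg_19/conv5/conv5_4'; the bare key would raise KeyError.
        feature = end_points[scope + 'vgg_19/conv5/conv5_4']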
@@ -344,169 +344,121 @@ def VGG19_slim(input, type, reuse, scope):


 # Define the whole network architecture
-def SRGAN(inputs, targets, FLAGS, devices = ['/gpu:%d'%i for i in range(8)]):
+def SRGAN(inputs, targets, FLAGS):
     # Define the container of the parameter
     Network = collections.namedtuple('Network', 'discrim_real_output, discrim_fake_output, discrim_loss, \
         discrim_grads_and_vars, adversarial_loss, content_loss, gen_grads_and_vars, gen_output, train, global_step, \
         learning_rate')
 
-    #generator tower lists
-    tower_grads = []
-    tower_outputs = []
-    #discriminator tower lists
-    tower_grads_d = []
-    tower_outputs_real_d = []
-    tower_outputs_fake_d = []
-    tower_discriminator_global = []
-
-    with tf.device('/gpu:0'):
-        split_inputs = tf.split(inputs, len(devices), axis=0)
-        split_targets = tf.split(targets, len(devices), axis=0)
-        # Define the learning rate and global step
-        with tf.variable_scope('get_learning_rate_and_global_step'):
-            global_step = tf.contrib.framework.get_or_create_global_step()
-            learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate,
-                                                       staircase=FLAGS.stair)
-            incr_global_step = tf.assign(global_step, global_step + 1)
-
-
-    with tf.variable_scope(tf.get_variable_scope()) as scope:
-        for i, (inputs, targets, dev) in enumerate(zip(split_inputs, split_targets, devices)):
-            with tf.device(dev):
-                with tf.name_scope('tower%d'%i):
-                    # Build the generator part
-                    with tf.variable_scope('generator'):
-                        output_channel = targets.get_shape().as_list()[-1]
-                        gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS)
-                        gen_output.set_shape([FLAGS.batch_size/len(devices), FLAGS.crop_size*4, FLAGS.crop_size*4, 3])
-
-                    tower_outputs.append(gen_output)
-
-                    # Build the fake discriminator
-                    with tf.name_scope('fake_discriminator'):
-                        with tf.variable_scope('discriminator', reuse=False):
-                            discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS)
-                            tower_outputs_fake_d.append(discrim_fake_output)
+    # Build the generator part
+    with tf.variable_scope('generator'):
+        output_channel = targets.get_shape().as_list()[-1]
+        gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS)
+        gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size*4, FLAGS.crop_size*4, 3])
+
+    # Build the fake discriminator
+    with tf.name_scope('fake_discriminator'):
+        with tf.variable_scope('discriminator', reuse=False):
+            discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS)
+
+    # Build the real discriminator
+    with tf.name_scope('real_discriminator'):
+        with tf.variable_scope('discriminator', reuse=True):
+            discrim_real_output = discriminator(targets, FLAGS=FLAGS)
+
+    # Use the VGG54 feature
+    if FLAGS.perceptual_mode == 'VGG54':
+        with tf.name_scope('vgg19_1') as scope:
+            extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
+        with tf.name_scope('vgg19_2') as scope:
+            extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
+
+    # Use the VGG22 feature
+    elif FLAGS.perceptual_mode == 'VGG22':
+        with tf.name_scope('vgg19_1') as scope:
+            extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
+        with tf.name_scope('vgg19_2') as scope:
+            extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
+
+    # Use MSE loss directly
+    elif FLAGS.perceptual_mode == 'MSE':
+        extracted_feature_gen = gen_output
+        extracted_feature_target = targets

-                    # Build the real discriminator
-                    with tf.name_scope('real_discriminator'):
-                        with tf.variable_scope('discriminator', reuse=True):
-                            discrim_real_output = discriminator(targets, FLAGS=FLAGS)
-                            tower_outputs_real_d.append(discrim_real_output)
-
-                    # Use the VGG54 feature
-                    if FLAGS.perceptual_mode == 'VGG54':
-                        with tf.name_scope('vgg19_1') as scope:
-                            extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
-                        with tf.name_scope('vgg19_2') as scope:
-                            extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
-
-                    # Use the VGG22 feature
-                    elif FLAGS.perceptual_mode == 'VGG22':
-                        with tf.name_scope('vgg19_1') as scope:
-                            extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
-                        with tf.name_scope('vgg19_2') as scope:
-                            extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
-
-                    # Use MSE loss directly
-                    elif FLAGS.perceptual_mode == 'MSE':
-                        extracted_feature_gen = gen_output
-                        extracted_feature_target = targets
-
-                    else:
-                        raise NotImplementedError('Unknown perceptual type!!')
-
-                    # Calculating the generator loss
-                    with tf.variable_scope('generator_loss'):
-                        # Content loss
-                        with tf.variable_scope('content_loss'):
-                            # Compute the euclidean distance between the two features
-                            diff = extracted_feature_gen - extracted_feature_target
-                            if FLAGS.perceptual_mode == 'MSE':
-                                content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
-                            else:
-                                content_loss = FLAGS.vgg_scaling*tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
-
-                        with tf.variable_scope('adversarial_loss'):
-                            adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS))
-
-                        gen_loss = content_loss + (FLAGS.ratio)*adversarial_loss
-                        print(adversarial_loss.get_shape())
-                        print(content_loss.get_shape())
-
-                    # Calculating the discriminator loss
-                    with tf.variable_scope('discriminator_loss'):
-                        discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS)
-                        discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS)
-
-                        discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss))
-
-                    # Define the learning rate and global step
-                    with tf.variable_scope('get_learning_rate_and_global_step'):
-                        global_step = tf.contrib.framework.get_or_create_global_step()
-                        learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, staircase=FLAGS.stair)
-                        incr_global_step = tf.assign(global_step, global_step + 1)
-
-                    with tf.variable_scope('dicriminator_train', reuse=tf.AUTO_REUSE):
-                        discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
-                        discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
-                        discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars)
-                        #discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars)
-                        #scope.reuse_variables()
-                        tower_grads_d.append(discrim_grads_and_vars)
-
-                    scope.reuse_variables()
-                    with tf.variable_scope('generator_train'):
-                        # Need to wait discriminator to perform train step
-                        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
-                            gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
-                            gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
-                            gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
-                            #gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)
-                            tower_grads.append(gen_grads_and_vars)
+    else:
+        raise NotImplementedError('Unknown perceptual type!!')
+
+    # Calculating the generator loss
+    with tf.variable_scope('generator_loss'):
+        # Content loss
+        with tf.variable_scope('content_loss'):
+            # Compute the euclidean distance between the two features
+            diff = extracted_feature_gen - extracted_feature_target
+            if FLAGS.perceptual_mode == 'MSE':
+                content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+            else:
+                content_loss = FLAGS.vgg_scaling*tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+
+        with tf.variable_scope('adversarial_loss'):
+            adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS))
+
+        gen_loss = content_loss + (FLAGS.ratio)*adversarial_loss
+        print(adversarial_loss.get_shape())
+        print(content_loss.get_shape())
+
+    # Calculating the discriminator loss
+    with tf.variable_scope('discriminator_loss'):
+        discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS)
+        discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS)
+
+        discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss))
+
+    # Define the learning rate and global step
+    with tf.variable_scope('get_learning_rate_and_global_step'):
+        global_step = tf.contrib.framework.get_or_create_global_step()
+        learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, staircase=FLAGS.stair)
+        incr_global_step = tf.assign(global_step, global_step + 1)
+
+    with tf.variable_scope('dicriminator_train'):
+        discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
+        discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
+        discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars)
+        discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars)
+
+    with tf.variable_scope('generator_train'):
+        # Need to wait discriminator to perform train step
+        with tf.control_dependencies([discrim_train] + tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+            gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+            gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
+            gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
+            gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)
 
     #[ToDo] If we do not use moving average on loss??
     exp_averager = tf.train.ExponentialMovingAverage(decay=0.99)
     update_loss = exp_averager.apply([discrim_loss, adversarial_loss, content_loss])


-    #discriminator aggregation
-    avg_grads_d = average_gradients(tower_grads_d)
-    discrim_train = discrim_optimizer.apply_gradients(avg_grads_d)
-
-    all_outputs_real_d = tf.concat(tower_outputs_real_d, axis=0)
-    all_outputs_fake_d = tf.concat(tower_outputs_fake_d, axis=0)
-
-    with tf.control_dependencies([discrim_train]):
-        #generator aggregation
-        avg_grads = average_gradients(tower_grads)
-        gen_train = gen_optimizer.apply_gradients(avg_grads)
-
-    all_outputs_g = tf.concat(tower_outputs, axis=0)


     return Network(
-        discrim_real_output = all_outputs_real_d,
-        discrim_fake_output = all_outputs_fake_d,
+        discrim_real_output = discrim_real_output,
+        discrim_fake_output = discrim_fake_output,
         discrim_loss = exp_averager.average(discrim_loss),
         discrim_grads_and_vars = discrim_grads_and_vars,
         adversarial_loss = exp_averager.average(adversarial_loss),
         content_loss = exp_averager.average(content_loss),
         gen_grads_and_vars = gen_grads_and_vars,
-        gen_output = all_outputs_g,
+        gen_output = gen_output,
         train = tf.group(update_loss, incr_global_step, gen_train),
         global_step = global_step,
         learning_rate = learning_rate
     )


-def SRResnet(inputs, targets, FLAGS, devices = ['/gpu:%d'%i for i in range(8)]):
+def SRResnet(inputs, targets, FLAGS, devices = ['/gpu:%d'%i for i in range(4)]):
     # Define the container of the parameter
     Network = collections.namedtuple('Network', 'content_loss, gen_grads_and_vars, gen_output, train, global_step, \
         learning_rate')
     tower_grads = []
     tower_outputs = []
-    with tf.device('/gpu:0'):
+    with tf.device('/cpu:0'):
         split_inputs = tf.split(inputs, len(devices), axis=0)
         split_targets = tf.split(targets, len(devices), axis=0)
         # Define the learning rate and global step
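
The net effect of the large hunk above: SRGAN drops the eight-way tower loop, builds the graph once on the default device, and returns the discriminator and generator outputs directly instead of concatenated tower tensors. A minimal driver sketch under assumptions (not part of this commit; queue-fed `inputs`/`targets` already built by the data pipeline, `max_iter` is a hypothetical name, and the repo's real training loop lives elsewhere):

    net = SRGAN(inputs, targets, FLAGS)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.start_queue_runners(sess=sess)  # assuming a queue-based input pipeline
        for _ in range(max_iter):  # hypothetical iteration count
            # net.train groups update_loss, incr_global_step and gen_train;
            # gen_train already waits on discrim_train via control_dependencies.
            _, d_loss, step = sess.run([net.train, net.discrim_loss, net.global_step])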
@@ -621,7 +573,6 @@ def save_images(fetches, FLAGS, step=None):
             f.write(contents)
         filesets.append(fileset)
     return filesets
-
 def average_gradients(tower_grads):
     """Calculate the average gradient for each shared variable across all towers.
     Note that this function provides a synchronization point across all towers.
@@ -665,3 +616,5 @@ def average_gradients(tower_grads):
 
 
 
+
+
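SRResnet still splits each batch across `devices` and merges per-tower gradients with `average_gradients`, whose body is collapsed in this view. Its docstring matches the standard TensorFlow multi-tower idiom; a sketch of that idiom follows (an assumption for illustration; the file's actual body may differ in detail):

    def average_gradients(tower_grads):
        """Average lists of (gradient, variable) pairs, one list per tower."""
        average_grads = []
        for grad_and_vars in zip(*tower_grads):  # the i-th (g, v) pair of every tower
            grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
            grad = tf.reduce_mean(tf.concat(grads, axis=0), 0)
            # Variables are shared across towers, so any tower's pointer works.
            average_grads.append((grad, grad_and_vars[0][1]))
        return average_grads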