diff --git a/lib/model.py b/lib/model.py index d534fc9..ca115c5 100644 --- a/lib/model.py +++ b/lib/model.py @@ -332,9 +332,9 @@ def discriminator_block(inputs, output_channel, kernel_size, stride, scope): def VGG19_slim(input, type, reuse, scope): # Define the feature to extract according to the type of perceptual if type == 'VGG54': - target_layer = 'vgg_19/conv5/conv5_4' + target_layer = scope + 'vgg_19/conv5/conv5_4' elif type == 'VGG22': - target_layer = 'vgg_19/conv2/conv2_2' + target_layer = scope + 'vgg_19/conv2/conv2_2' else: raise NotImplementedError('Unknown perceptual type') _, output = vgg_19(input, is_training=False, reuse=reuse) @@ -344,169 +344,121 @@ def VGG19_slim(input, type, reuse, scope): # Define the whole network architecture -def SRGAN(inputs, targets, FLAGS, devices = ['/gpu:%d'%i for i in range(8)]): +def SRGAN(inputs, targets, FLAGS): # Define the container of the parameter Network = collections.namedtuple('Network', 'discrim_real_output, discrim_fake_output, discrim_loss, \ discrim_grads_and_vars, adversarial_loss, content_loss, gen_grads_and_vars, gen_output, train, global_step, \ learning_rate') - #generator tower lists - tower_grads = [] - tower_outputs = [] - #discriminator tower lists - tower_grads_d = [] - tower_outputs_real_d = [] - tower_outputs_fake_d = [] - tower_discriminator_global = [] - - with tf.device('/gpu:0'): - split_inputs = tf.split(inputs, len(devices), axis=0) - split_targets = tf.split(targets, len(devices), axis=0) - # Define the learning rate and global step - with tf.variable_scope('get_learning_rate_and_global_step'): - global_step = tf.contrib.framework.get_or_create_global_step() - learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, - staircase=FLAGS.stair) - incr_global_step = tf.assign(global_step, global_step + 1) - - - with tf.variable_scope(tf.get_variable_scope()) as scope: - for i, (inputs, targets, dev) in enumerate(zip(split_inputs, split_targets, devices)): - with tf.device(dev): - with tf.name_scope('tower%d'%i): - # Build the generator part - with tf.variable_scope('generator'): - output_channel = targets.get_shape().as_list()[-1] - gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS) - gen_output.set_shape([FLAGS.batch_size/len(devices) , FLAGS.crop_size*4, FLAGS.crop_size*4, 3]) - - tower_outputs.append(gen_output) - - # Build the fake discriminator - with tf.name_scope('fake_discriminator'): - with tf.variable_scope('discriminator', reuse=False): - discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS) - tower_outputs_fake_d.append(discrim_fake_output) + # Build the generator part + with tf.variable_scope('generator'): + output_channel = targets.get_shape().as_list()[-1] + gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS) + gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size*4, FLAGS.crop_size*4, 3]) + + # Build the fake discriminator + with tf.name_scope('fake_discriminator'): + with tf.variable_scope('discriminator', reuse=False): + discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS) + + # Build the real discriminator + with tf.name_scope('real_discriminator'): + with tf.variable_scope('discriminator', reuse=True): + discrim_real_output = discriminator(targets, FLAGS=FLAGS) + + # Use the VGG54 feature + if FLAGS.perceptual_mode == 'VGG54': + with tf.name_scope('vgg19_1') as scope: + extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope) + with tf.name_scope('vgg19_2') as scope: + extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope) + + # Use the VGG22 feature + elif FLAGS.perceptual_mode == 'VGG22': + with tf.name_scope('vgg19_1') as scope: + extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope) + with tf.name_scope('vgg19_2') as scope: + extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope) + + # Use MSE loss directly + elif FLAGS.perceptual_mode == 'MSE': + extracted_feature_gen = gen_output + extracted_feature_target = targets - # Build the real discriminator - with tf.name_scope('real_discriminator'): - with tf.variable_scope('discriminator', reuse=True): - discrim_real_output = discriminator(targets, FLAGS=FLAGS) - tower_outputs_real_d.append(discrim_real_output) - - # Use the VGG54 feature - if FLAGS.perceptual_mode == 'VGG54': - with tf.name_scope('vgg19_1') as scope: - extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope) - with tf.name_scope('vgg19_2') as scope: - extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope) - - # Use the VGG22 feature - elif FLAGS.perceptual_mode == 'VGG22': - with tf.name_scope('vgg19_1') as scope: - extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope) - with tf.name_scope('vgg19_2') as scope: - extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope) - - # Use MSE loss directly - elif FLAGS.perceptual_mode == 'MSE': - extracted_feature_gen = gen_output - extracted_feature_target = targets - - else: - raise NotImplementedError('Unknown perceptual type!!') - - # Calculating the generator loss - with tf.variable_scope('generator_loss'): - # Content loss - with tf.variable_scope('content_loss'): - # Compute the euclidean distance between the two features - diff = extracted_feature_gen - extracted_feature_target - if FLAGS.perceptual_mode == 'MSE': - content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) - else: - content_loss = FLAGS.vgg_scaling*tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) - - with tf.variable_scope('adversarial_loss'): - adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS)) - - gen_loss = content_loss + (FLAGS.ratio)*adversarial_loss - print(adversarial_loss.get_shape()) - print(content_loss.get_shape()) - - # Calculating the discriminator loss - with tf.variable_scope('discriminator_loss'): - discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS) - discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS) - - discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss)) - - # Define the learning rate and global step - with tf.variable_scope('get_learning_rate_and_global_step'): - global_step = tf.contrib.framework.get_or_create_global_step() - learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, staircase=FLAGS.stair) - incr_global_step = tf.assign(global_step, global_step + 1) - - with tf.variable_scope('dicriminator_train',reuse=tf.AUTO_REUSE): - discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') - discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta) - discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars) - #discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars) - #scope.reuse_variables() - tower_grads_d.append(discrim_grads_and_vars) - - scope.reuse_variables() - with tf.variable_scope('generator_train'): - # Need to wait discriminator to perform train step - with tf.control_dependencies( tf.get_collection(tf.GraphKeys.UPDATE_OPS)): - gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') - gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta) - gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars) - #gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars) - tower_grads.append(gen_grads_and_vars) + else: + raise NotImplementedError('Unknown perceptual type!!') + + # Calculating the generator loss + with tf.variable_scope('generator_loss'): + # Content loss + with tf.variable_scope('content_loss'): + # Compute the euclidean distance between the two features + diff = extracted_feature_gen - extracted_feature_target + if FLAGS.perceptual_mode == 'MSE': + content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) + else: + content_loss = FLAGS.vgg_scaling*tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) + + with tf.variable_scope('adversarial_loss'): + adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS)) + + gen_loss = content_loss + (FLAGS.ratio)*adversarial_loss + print(adversarial_loss.get_shape()) + print(content_loss.get_shape()) + + # Calculating the discriminator loss + with tf.variable_scope('discriminator_loss'): + discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS) + discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS) + + discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss)) + + # Define the learning rate and global step + with tf.variable_scope('get_learning_rate_and_global_step'): + global_step = tf.contrib.framework.get_or_create_global_step() + learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, staircase=FLAGS.stair) + incr_global_step = tf.assign(global_step, global_step + 1) + + with tf.variable_scope('dicriminator_train'): + discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') + discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta) + discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars) + discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars) + + with tf.variable_scope('generator_train'): + # Need to wait discriminator to perform train step + with tf.control_dependencies([discrim_train]+ tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') + gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta) + gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars) + gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars) #[ToDo] If we do not use moving average on loss?? exp_averager = tf.train.ExponentialMovingAverage(decay=0.99) update_loss = exp_averager.apply([discrim_loss, adversarial_loss, content_loss]) - - - #discriminator aggregation - avg_grads_d = average_gradients(tower_grads_d) - discrim_train = discrim_optimizer.apply_gradients(avg_grads_d) - - all_outputs_real_d = tf.concat(tower_outputs_real_d, axis=0) - all_outputs_fake_d = tf.concat(tower_outputs_fake_d, axis=0) - - with tf.control_dependencies([discrim_train]): - #generator aggregation - avg_grads = average_gradients(tower_grads) - gen_train = gen_optimizer.apply_gradients(avg_grads) - all_outputs_g = tf.concat(tower_outputs, axis=0) - - return Network( - discrim_real_output = all_outputs_real_d, - discrim_fake_output = all_outputs_fake_d, + discrim_real_output = discrim_real_output, + discrim_fake_output = discrim_fake_output, discrim_loss = exp_averager.average(discrim_loss), discrim_grads_and_vars = discrim_grads_and_vars, adversarial_loss = exp_averager.average(adversarial_loss), content_loss = exp_averager.average(content_loss), gen_grads_and_vars = gen_grads_and_vars, - gen_output = all_outputs_g, + gen_output = gen_output, train = tf.group(update_loss, incr_global_step, gen_train), global_step = global_step, learning_rate = learning_rate ) -def SRResnet(inputs, targets, FLAGS, devices = ['/gpu:%d'%i for i in range(8)]): +def SRResnet(inputs, targets, FLAGS, devices = ['/gpu:%d'%i for i in range(4)]): # Define the container of the parameter Network = collections.namedtuple('Network', 'content_loss, gen_grads_and_vars, gen_output, train, global_step, \ learning_rate') tower_grads = [] tower_outputs = [] - with tf.device('/gpu:0'): + with tf.device('/cpu:0'): split_inputs = tf.split(inputs, len(devices), axis=0) split_targets = tf.split(targets, len(devices), axis=0) # Define the learning rate and global step @@ -621,7 +573,6 @@ def save_images(fetches, FLAGS, step=None): f.write(contents) filesets.append(fileset) return filesets - def average_gradients(tower_grads): """Calculate the average gradient for each shared variable across all towers. Note that this function provides a synchronization point across all towers. @@ -665,3 +616,5 @@ def average_gradients(tower_grads): + +