diff --git a/lib/model.py b/lib/model.py
index ca115c5..07ee0a3 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -25,7 +25,7 @@ def data_loader(FLAGS):
             raise ValueError('Input directory not found')

         image_list_LR = os.listdir(FLAGS.input_dir_LR)
-        image_list_LR = [_ for _ in image_list_LR if _.endswith('.png')]
+        image_list_LR = [_ for _ in image_list_LR if _.endswith('.jpg')]
         if len(image_list_LR)==0:
             raise Exception('No png files in the input directory')
@@ -48,8 +48,8 @@ def data_loader(FLAGS):
             reader = tf.WholeFileReader(name='image_reader')
             image_LR = tf.read_file(output[0])
             image_HR = tf.read_file(output[1])
-            input_image_LR = tf.image.decode_png(image_LR, channels=3)
-            input_image_HR = tf.image.decode_png(image_HR, channels=3)
+            input_image_LR = tf.image.decode_jpeg(image_LR, channels=3)
+            input_image_HR = tf.image.decode_jpeg(image_HR, channels=3)
             input_image_LR = tf.image.convert_image_dtype(input_image_LR, dtype=tf.float32)
             input_image_HR = tf.image.convert_image_dtype(input_image_HR, dtype=tf.float32)
@@ -73,9 +73,9 @@ def data_loader(FLAGS):
             # Set the shape of the input image. the target will have 4X size
             input_size = tf.shape(inputs)
             target_size = tf.shape(targets)
-            offset_w = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[1], tf.float32) - FLAGS.crop_size)),
+            offset_w = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[1], tf.float32) - FLAGS.crop_size)),
                                dtype=tf.int32)
-            offset_h = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[0], tf.float32) - FLAGS.crop_size)),
+            offset_h = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[0], tf.float32) - FLAGS.crop_size)),
                                dtype=tf.int32)

             if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
@@ -148,8 +148,8 @@ def test_data_loader(FLAGS):
         raise ValueError('Input directory not found')

     image_list_LR_temp = os.listdir(FLAGS.input_dir_LR)
-    image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'png']
-    image_list_HR = [os.path.join(FLAGS.input_dir_HR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'png']
+    image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'jpg']
+    image_list_HR = [os.path.join(FLAGS.input_dir_HR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'jpg']

     # Read in and preprocess the images
     def preprocess_test(name, mode):
@@ -192,7 +192,7 @@ def inference_data_loader(FLAGS):
         raise ValueError('Input directory not found')

     image_list_LR_temp = os.listdir(FLAGS.input_dir_LR)
-    image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'png']
+    image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'jpg']

     # Read in and preprocess the images
     def preprocess_test(name):
@@ -227,10 +227,10 @@ def generator(gen_inputs, gen_output_channels, reuse=False, FLAGS=None):
     # The Bx residual blocks
     def residual_block(inputs, output_channel, stride, scope):
         with tf.variable_scope(scope):
-            net = conv2(inputs, 3, output_channel, stride, use_bias=False, scope='conv_1')
+            net = conv2(inputs, 3, output_channel, stride, use_bias=False, scope='conv_1', norm=FLAGS.w_norm)
             net = batchnorm(net, FLAGS.is_training)
             net = prelu_tf(net)
-            net = conv2(net, 3, output_channel, stride, use_bias=False, scope='conv_2')
+            net = conv2(net, 3, output_channel, stride, use_bias=False, scope='conv_2', norm=FLAGS.w_norm)
             net = batchnorm(net, FLAGS.is_training)
             net = net + inputs
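A note on the loader hunks above: the '.jpg' filter is now hard-coded separately in data_loader, test_data_loader, and inference_data_loader, and the unchanged 'No png files' message is now stale. A minimal shared helper is sketched below; the names ALLOWED_EXTS, list_images, and decode_any are hypothetical and not part of this diff. tf.image.decode_image dispatches on the file header rather than the extension, so mixed PNG/JPEG folders would also work:

    # Hypothetical helpers, not part of this diff: one extension whitelist
    # instead of three per-loader literals, and a header-sniffing decoder.
    ALLOWED_EXTS = ('.png', '.jpg', '.jpeg')

    def list_images(directory):
        # Keep only recognised image files, joined to their directory.
        return [os.path.join(directory, n) for n in os.listdir(directory)
                if n.lower().endswith(ALLOWED_EXTS)]

    def decode_any(raw_bytes):
        # decode_image picks decode_png/decode_jpeg from the file header.
        img = tf.image.decode_image(raw_bytes, channels=3)
        img.set_shape([None, None, 3])  # decode_image leaves the static shape unknown
        return tf.image.convert_image_dtype(img, dtype=tf.float32)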
@@ -240,7 +240,7 @@ def residual_block(inputs, output_channel, stride, scope):
     with tf.variable_scope('generator_unit', reuse=reuse):
         # The input layer
         with tf.variable_scope('input_stage'):
-            net = conv2(gen_inputs, 9, 64, 1, scope='conv')
+            net = conv2(gen_inputs, 9, 64, 1, scope='conv', norm=FLAGS.w_norm)
             net = prelu_tf(net)
             stage1_output = net
@@ -251,46 +251,49 @@ def residual_block(inputs, output_channel, stride, scope):
             net = residual_block(net, 64, 1, name_scope)

         with tf.variable_scope('resblock_output'):
-            net = conv2(net, 3, 64, 1, use_bias=False, scope='conv')
+            net = conv2(net, 3, 64, 1, use_bias=False, scope='conv', norm=FLAGS.w_norm)
             net = batchnorm(net, FLAGS.is_training)
-
+
+        if FLAGS.attention:
+            net = attention(net, 64, reuse=reuse, FLAGS=FLAGS)
+
         net = net + stage1_output

         with tf.variable_scope('subpixelconv_stage1'):
-            net = conv2(net, 3, 256, 1, scope='conv')
+            net = conv2(net, 3, 256, 1, scope='conv', norm=FLAGS.w_norm)
             net = pixelShuffler(net, scale=2)
             net = prelu_tf(net)

         with tf.variable_scope('subpixelconv_stage2'):
-            net = conv2(net, 3, 256, 1, scope='conv')
+            net = conv2(net, 3, 256, 1, scope='conv', norm=FLAGS.w_norm)
             net = pixelShuffler(net, scale=2)
             net = prelu_tf(net)

         with tf.variable_scope('output_stage'):
-            net = conv2(net, 9, gen_output_channels, 1, scope='conv')
+            net = conv2(net, 9, gen_output_channels, 1, scope='conv', norm=FLAGS.w_norm)

     return net


 # Definition of the discriminator
-def discriminator(dis_inputs, FLAGS=None):
+def discriminator(dis_inputs, FLAGS=None, reuse=False):
     if FLAGS is None:
         raise ValueError('No FLAGS is provided for generator')

     # Define the discriminator block
     def discriminator_block(inputs, output_channel, kernel_size, stride, scope):
         with tf.variable_scope(scope):
-            net = conv2(inputs, kernel_size, output_channel, stride, use_bias=False, scope='conv1')
+            net = conv2(inputs, kernel_size, output_channel, stride, use_bias=False, scope='conv1', norm=FLAGS.w_norm)
             net = batchnorm(net, FLAGS.is_training)
             net = lrelu(net, 0.2)

         return net

     with tf.device('/gpu:0'):
-        with tf.variable_scope('discriminator_unit'):
+        with tf.variable_scope('discriminator_unit', reuse=reuse):
             # The input layer
             with tf.variable_scope('input_stage'):
-                net = conv2(dis_inputs, 3, 64, 1, scope='conv')
+                net = conv2(dis_inputs, 3, 64, 1, scope='conv', norm=FLAGS.w_norm)
                 net = lrelu(net, 0.2)

             # The discriminator block part
@@ -305,7 +308,10 @@ def discriminator_block(inputs, output_channel, kernel_size, stride, scope):
             # block 4
             net = discriminator_block(net, 256, 3, 1, 'disblock_4')
-
+
+            if FLAGS.attention:
+                net = attention(net, 256, reuse=reuse, FLAGS=FLAGS)
+
             # block 5
             net = discriminator_block(net, 256, 3, 2, 'disblock_5')
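Both attention gates above depend on two new flags, FLAGS.attention and FLAGS.w_norm, whose definitions are not part of this diff. A sketch of what the declarations might look like, assuming they live in main.py where this repository defines its other flags (names inferred from usage here; the defaults are guesses):

    # Assumed flag definitions, not in this diff.
    tf.app.flags.DEFINE_boolean('attention', False,
                                'Insert a self-attention block into the generator and discriminator.')
    tf.app.flags.DEFINE_string('w_norm', None,
                               'Weight-normalization mode forwarded to conv2/denselayer.')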
@@ -318,23 +324,41 @@ def discriminator_block(inputs, output_channel, kernel_size, stride, scope):
             # The dense layer 1
             with tf.variable_scope('dense_layer_1'):
                 net = slim.flatten(net)
-                net = denselayer(net, 1024)
+                net = denselayer(net, 1024, norm=FLAGS.w_norm)
                 net = lrelu(net, 0.2)

             # The dense layer 2
             with tf.variable_scope('dense_layer_2'):
-                net = denselayer(net, 1)
+                net = denselayer(net, 1, norm=FLAGS.w_norm)
                 net = tf.nn.sigmoid(net)

     return net

+
+def attention(x, ch, scope='attention', reuse=False, FLAGS=None):
+    with tf.variable_scope(scope, reuse=reuse):
+        f = conv2(x, 1, ch // 8, 1, scope='f_conv', norm=FLAGS.w_norm)  # [bs, h, w, c']
+        g = conv2(x, 1, ch // 8, 1, scope='g_conv', norm=FLAGS.w_norm)  # [bs, h, w, c']
+        h = conv2(x, 1, ch, 1, scope='h_conv', norm=FLAGS.w_norm)  # [bs, h, w, c]
+
+        # N = h * w
+        s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]
+
+        beta = tf.nn.softmax(s, axis=-1)  # attention map
+
+        o = tf.matmul(beta, hw_flatten(h))  # [bs, N, C]
+        gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
+
+        o = tf.reshape(o, shape=x.shape)  # [bs, h, w, C]
+        x = gamma * o + x
+
+    return x


 def VGG19_slim(input, type, reuse, scope):
     # Define the feature to extract according to the type of perceptual
     if type == 'VGG54':
-        target_layer = scope + 'vgg_19/conv5/conv5_4'
+        target_layer = 'vgg_19/conv5/conv5_4'
     elif type == 'VGG22':
-        target_layer = scope + 'vgg_19/conv2/conv2_2'
+        target_layer = 'vgg_19/conv2/conv2_2'
     else:
         raise NotImplementedError('Unknown perceptual type')
     _, output = vgg_19(input, is_training=False, reuse=reuse)
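In the attention block above, gamma is initialized to zero so the block starts as an identity mapping and the network gradually learns how much attention to mix in (the SAGAN recipe). hw_flatten is assumed to be provided by lib/ops.py alongside conv2; a minimal sketch consistent with the shape comments above, mapping [bs, h, w, c] to [bs, h*w, c]:

    def hw_flatten(x):
        # Collapse the spatial grid into one axis of N = h*w positions.
        shape = x.get_shape().as_list()  # [bs, h, w, c]; bs may be None
        return tf.reshape(x, [-1, shape[1] * shape[2], shape[3]])

With that, s = matmul(g, f, transpose_b=True) is the [bs, N, N] affinity between every pair of spatial positions, and the row-wise softmax turns it into the attention map beta.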
@@ -344,121 +368,170 @@ def VGG19_slim(input, type, reuse, scope):


 # Define the whole network architecture
-def SRGAN(inputs, targets, FLAGS):
+def SRGAN(inputs, targets, FLAGS, devices=['/gpu:%d' % i for i in range(8)]):
     # Define the container of the parameter
     Network = collections.namedtuple('Network', 'discrim_real_output, discrim_fake_output, discrim_loss, \
         discrim_grads_and_vars, adversarial_loss, content_loss, gen_grads_and_vars, gen_output, train, global_step, \
         learning_rate')

-    # Build the generator part
-    with tf.variable_scope('generator'):
-        output_channel = targets.get_shape().as_list()[-1]
-        gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS)
-        gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size*4, FLAGS.crop_size*4, 3])
-
-    # Build the fake discriminator
-    with tf.name_scope('fake_discriminator'):
-        with tf.variable_scope('discriminator', reuse=False):
-            discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS)
-
-    # Build the real discriminator
-    with tf.name_scope('real_discriminator'):
-        with tf.variable_scope('discriminator', reuse=True):
-            discrim_real_output = discriminator(targets, FLAGS=FLAGS)
-
-    # Use the VGG54 feature
-    if FLAGS.perceptual_mode == 'VGG54':
-        with tf.name_scope('vgg19_1') as scope:
-            extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
-        with tf.name_scope('vgg19_2') as scope:
-            extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
-
-    # Use the VGG22 feature
-    elif FLAGS.perceptual_mode == 'VGG22':
-        with tf.name_scope('vgg19_1') as scope:
-            extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
-        with tf.name_scope('vgg19_2') as scope:
-            extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
-
-    # Use MSE loss directly
-    elif FLAGS.perceptual_mode == 'MSE':
-        extracted_feature_gen = gen_output
-        extracted_feature_target = targets
+    # generator tower lists
+    tower_grads = []
+    tower_outputs = []
+    # discriminator tower lists
+    tower_grads_d = []
+    tower_outputs_real_d = []
+    tower_outputs_fake_d = []
+    tower_discriminator_global = []

-    else:
-        raise NotImplementedError('Unknown perceptual type!!')
-
-    # Calculating the generator loss
-    with tf.variable_scope('generator_loss'):
-        # Content loss
-        with tf.variable_scope('content_loss'):
-            # Compute the euclidean distance between the two features
-            diff = extracted_feature_gen - extracted_feature_target
-            if FLAGS.perceptual_mode == 'MSE':
-                content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
-            else:
-                content_loss = FLAGS.vgg_scaling*tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
-
-        with tf.variable_scope('adversarial_loss'):
-            adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS))
-
-        gen_loss = content_loss + (FLAGS.ratio)*adversarial_loss
-        print(adversarial_loss.get_shape())
-        print(content_loss.get_shape())
-
-    # Calculating the discriminator loss
-    with tf.variable_scope('discriminator_loss'):
-        discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS)
-        discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS)
-
-        discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss))
-
-    # Define the learning rate and global step
-    with tf.variable_scope('get_learning_rate_and_global_step'):
-        global_step = tf.contrib.framework.get_or_create_global_step()
-        learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, staircase=FLAGS.stair)
-        incr_global_step = tf.assign(global_step, global_step + 1)
-
-    with tf.variable_scope('dicriminator_train'):
-        discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
-        discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
-        discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars)
-        discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars)
-
-    with tf.variable_scope('generator_train'):
-        # Need to wait discriminator to perform train step
-        with tf.control_dependencies([discrim_train] + tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
-            gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
-            gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
-            gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
-            gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)
+    with tf.device('/gpu:0'):
+        split_inputs = tf.split(inputs, len(devices), axis=0)
+        split_targets = tf.split(targets, len(devices), axis=0)
+        # Define the learning rate and global step
+        with tf.variable_scope('get_learning_rate_and_global_step'):
+            global_step = tf.contrib.framework.get_or_create_global_step()
+            learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate,
+                                                       staircase=FLAGS.stair)
+            incr_global_step = tf.assign(global_step, global_step + 1)
+
+    with tf.variable_scope(tf.get_variable_scope()) as scope:
+        for i, (inputs, targets, dev) in enumerate(zip(split_inputs, split_targets, devices)):
+            with tf.device(dev):
+                with tf.name_scope('tower%d' % i):
+                    # Build the generator part
+                    with tf.variable_scope('generator'):
+                        output_channel = targets.get_shape().as_list()[-1]
+                        gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS)
+                        gen_output.set_shape([FLAGS.batch_size // len(devices), FLAGS.crop_size*4, FLAGS.crop_size*4, 3])
+
+                    tower_outputs.append(gen_output)
+
+                    # Build the fake discriminator
+                    with tf.name_scope('fake_discriminator'):
+                        with tf.variable_scope('discriminator', reuse=False):
+                            discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS)
+                    tower_outputs_fake_d.append(discrim_fake_output)
+
+                    # Build the real discriminator
+                    with tf.name_scope('real_discriminator'):
+                        with tf.variable_scope('discriminator', reuse=True):
+                            discrim_real_output = discriminator(targets, FLAGS=FLAGS)
+                    tower_outputs_real_d.append(discrim_real_output)
+
+                    # Use the VGG54 feature
+                    if FLAGS.perceptual_mode == 'VGG54':
+                        with tf.name_scope('vgg19_1') as vgg_scope:
+                            extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=vgg_scope)
+                        with tf.name_scope('vgg19_2') as vgg_scope:
+                            extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=vgg_scope)
+
+                    # Use the VGG22 feature
+                    elif FLAGS.perceptual_mode == 'VGG22':
+                        with tf.name_scope('vgg19_1') as vgg_scope:
+                            extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=vgg_scope)
+                        with tf.name_scope('vgg19_2') as vgg_scope:
+                            extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=vgg_scope)
+
+                    # Use MSE loss directly
+                    elif FLAGS.perceptual_mode == 'MSE':
+                        extracted_feature_gen = gen_output
+                        extracted_feature_target = targets
+
+                    else:
+                        raise NotImplementedError('Unknown perceptual type!!')
+
+                    # Calculating the generator loss
+                    with tf.variable_scope('generator_loss'):
+                        # Content loss
+                        with tf.variable_scope('content_loss'):
+                            # Compute the euclidean distance between the two features
+                            diff = extracted_feature_gen - extracted_feature_target
+                            if FLAGS.perceptual_mode == 'MSE':
+                                content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+                            else:
+                                content_loss = FLAGS.vgg_scaling*tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+
+                        with tf.variable_scope('adversarial_loss'):
+                            adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS))
+
+                        gen_loss = content_loss + (FLAGS.ratio)*adversarial_loss
+                        print(adversarial_loss.get_shape())
+                        print(content_loss.get_shape())
+
+                    # Calculating the discriminator loss
+                    with tf.variable_scope('discriminator_loss'):
+                        discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS)
+                        discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS)
+
+                        discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss))
+
+                    # Define the learning rate and global step
+                    with tf.variable_scope('get_learning_rate_and_global_step'):
+                        global_step = tf.contrib.framework.get_or_create_global_step()
+                        learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, staircase=FLAGS.stair)
+                        incr_global_step = tf.assign(global_step, global_step + 1)
+
+                    #scope.reuse_variables()
+                    with tf.variable_scope('dicriminator_train', reuse=tf.AUTO_REUSE):
+                        discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
+                        discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
+                        discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars)
+                        #discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars)
+                    #scope.reuse_variables()
+                    tower_grads_d.append(discrim_grads_and_vars)
+
+                    scope.reuse_variables()
+                    with tf.variable_scope('generator_train'):
+                        # Need to wait discriminator to perform train step
+                        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+                            gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+                            gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
+                            gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
+                            #gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)
+                    tower_grads.append(gen_grads_and_vars)

     #[ToDo] If we do not use moving average on loss??
     exp_averager = tf.train.ExponentialMovingAverage(decay=0.99)
     update_loss = exp_averager.apply([discrim_loss, adversarial_loss, content_loss])
+
+    # discriminator aggregation
+    avg_grads_d = average_gradients(tower_grads_d)
+    discrim_train = discrim_optimizer.apply_gradients(avg_grads_d)
+
+    all_outputs_real_d = tf.concat(tower_outputs_real_d, axis=0)
+    all_outputs_fake_d = tf.concat(tower_outputs_fake_d, axis=0)
+
+    with tf.control_dependencies([discrim_train]):
+        # generator aggregation
+        avg_grads = average_gradients(tower_grads)
+        gen_train = gen_optimizer.apply_gradients(avg_grads)
+    all_outputs_g = tf.concat(tower_outputs, axis=0)
+

     return Network(
-        discrim_real_output = discrim_real_output,
-        discrim_fake_output = discrim_fake_output,
+        discrim_real_output = all_outputs_real_d,
+        discrim_fake_output = all_outputs_fake_d,
         discrim_loss = exp_averager.average(discrim_loss),
         discrim_grads_and_vars = discrim_grads_and_vars,
         adversarial_loss = exp_averager.average(adversarial_loss),
         content_loss = exp_averager.average(content_loss),
         gen_grads_and_vars = gen_grads_and_vars,
-        gen_output = gen_output,
+        gen_output = all_outputs_g,
         train = tf.group(update_loss, incr_global_step, gen_train),
         global_step = global_step,
         learning_rate = learning_rate
     )


-def SRResnet(inputs, targets, FLAGS, devices = ['/gpu:%d'%i for i in range(4)]):
+def SRResnet(inputs, targets, FLAGS, devices = ['/gpu:%d'%i for i in range(8)]):
     # Define the container of the parameter
     Network = collections.namedtuple('Network', 'content_loss, gen_grads_and_vars, gen_output, train, global_step, \
         learning_rate')

     tower_grads = []
     tower_outputs = []
-    with tf.device('/cpu:0'):
+    with tf.device('/gpu:0'):
         split_inputs = tf.split(inputs, len(devices), axis=0)
         split_targets = tf.split(targets, len(devices), axis=0)
         # Define the learning rate and global step
@@ -552,7 +625,7 @@ def save_images(fetches, FLAGS, step=None):

         if FLAGS.mode == 'inference':
             kind = "outputs"
-            filename = name + ".png"
+            filename = name + ".jpg"
             if step is not None:
                 filename = "%08d-%s" % (step, filename)
             fileset[kind] = filename
@@ -562,8 +635,13 @@ def save_images(fetches, FLAGS, step=None):
                 f.write(contents)
             filesets.append(fileset)
         else:
+            psnr = fetches['psnr']
+            ssim = fetches['SSIM']
             for kind in ["inputs", "outputs", "targets"]:
-                filename = name + "-" + kind + ".png"
+                if kind == "outputs":
+                    filename = name + "-" + kind + " (PSNR: " + str(psnr) + " and SSIM: " + str(ssim) + ").jpg"
+                else:
+                    filename = name + "-" + kind + ".jpg"
                 if step is not None:
                     filename = "%08d-%s" % (step, filename)
                 fileset[kind] = filename
@@ -618,3 +696,4 @@ def average_gradients(tower_grads):
+
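Both SRGAN and SRResnet now funnel their per-tower gradients through average_gradients, whose definition sits at the end of lib/model.py and outside this diff's context lines. For reference, a minimal sketch in the style of the TensorFlow CIFAR-10 multi-GPU tutorial, assuming each entry of tower_grads is the (gradient, variable) list returned by compute_gradients on one device. Note also that tf.split(inputs, len(devices), axis=0) requires FLAGS.batch_size to be divisible by the device count:

    def average_gradients(tower_grads):
        # tower_grads: one list of (grad, var) pairs per tower, all towers
        # sharing the same variables in the same order.
        average_grads = []
        for grad_and_vars in zip(*tower_grads):  # iterate over variables
            grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
            grad = tf.reduce_mean(tf.concat(grads, axis=0), 0)
            average_grads.append((grad, grad_and_vars[0][1]))
        return average_grads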