From ecbfb4f8701f2281fe4b876a90f2486e59d4e69a Mon Sep 17 00:00:00 2001 From: Horatio Date: Sun, 16 Jul 2017 20:49:35 +0800 Subject: [PATCH] Initial project --- .gitignore | 2 + .idea/SRGAN_tensorflow.iml | 12 + .idea/misc.xml | 4 + .idea/modules.xml | 8 + .idea/workspace.xml | 132 ++++++ lib/__init__.py | 0 lib/model.py | 878 +++++++++++++++++++++++++++++++++++++ lib/model_dense.py | 95 ++++ lib/ops.py | 231 ++++++++++ main.py | 343 +++++++++++++++ test_SRGAN.sh | 15 + tool/__init__.py | 0 tool/convertPNG.py | 0 tool/resizeImage.py | 39 ++ train_SRGAN.sh | 29 ++ train_SRResnet.sh | 28 ++ 16 files changed, 1816 insertions(+) create mode 100644 .gitignore create mode 100644 .idea/SRGAN_tensorflow.iml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/workspace.xml create mode 100644 lib/__init__.py create mode 100644 lib/model.py create mode 100644 lib/model_dense.py create mode 100644 lib/ops.py create mode 100644 main.py create mode 100644 test_SRGAN.sh create mode 100644 tool/__init__.py create mode 100644 tool/convertPNG.py create mode 100644 tool/resizeImage.py create mode 100644 train_SRGAN.sh create mode 100644 train_SRResnet.sh diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ddbcc37 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.pyc +*.ckpt \ No newline at end of file diff --git a/.idea/SRGAN_tensorflow.iml b/.idea/SRGAN_tensorflow.iml new file mode 100644 index 0000000..6f63a63 --- /dev/null +++ b/.idea/SRGAN_tensorflow.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..795374c --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..4e2467f --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/workspace.xml b/.idea/workspace.xml new file mode 100644 index 0000000..a56e092 --- /dev/null +++ b/.idea/workspace.xml @@ -0,0 +1,132 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1500209151200 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/model.py b/lib/model.py new file mode 100644 index 0000000..c10435f --- /dev/null +++ b/lib/model.py @@ -0,0 +1,878 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from lib.ops import * +from lib.model_dense import generatorDense +import collections +import os +import math +from tensorflow.contrib.keras.api.keras.applications.vgg19 import VGG19 +from tensorflow.contrib.keras.api.keras.models import Model +import scipy.misc as sic +import numpy as np + + +# Define the data loader +def data_loader(FLAGS): + # Define the returned data batches + Data = collections.namedtuple('Data', 'paths, inputs, targets, image_count, steps_per_epoch') + + #Check the input directory + if FLAGS.input_dir == 'None': + raise ValueError('Input directory is not provided') + + if not os.path.exists(FLAGS.input_dir): + raise ValueError('Input directory not found') + + image_list = os.listdir(FLAGS.input_dir) + image_list = [_ for _ in image_list if _.endswith('.png')] + if 
len(image_list)==0: + raise Exception('No png files in the input directory') + + image_list = sorted(image_list) + image_list = [os.path.join(FLAGS.input_dir, _) for _ in image_list] + + with tf.variable_scope('load_image'): + # define the image list queue + image_list_queue = tf.train.string_input_producer(image_list, shuffle=(FLAGS.mode == 'Train'), capacity=FLAGS.name_queue_capacity) + print('[Queue] image list queue use shuffle: %s'%(FLAGS.mode == 'Train')) + + # Reading and decode the images + reader = tf.WholeFileReader(name='image_reader') + paths, image = reader.read(image_list_queue) + input_image = tf.image.decode_png(image) + input_image = tf.image.convert_image_dtype(input_image, dtype=tf.float32) + + assertion = tf.assert_equal(tf.shape(input_image)[2], 3, message="image does not have 3 channels") + with tf.control_dependencies([assertion]): + input_image = tf.identity(input_image) + + # Break apart the paired images and normalize to [-1, 1] + width = tf.shape(input_image)[1] + a_image = preprocess(input_image[:, :(width // 2), :]) + b_image = preprocess(input_image[:, (width // 2):, :]) + + + if FLAGS.which_direction == "AtoB": + inputs, targets = [a_image, b_image] + elif FLAGS.which_direction == "BtoA": + inputs, targets = [b_image, a_image] + else: + raise Exception("invalid direction") + + # The data augmentation part + with tf.variable_scope('data_preprocessing'): + with tf.variable_scope('random_crop'): + # Check whether perform crop + if (FLAGS.random_crop is True) and (FLAGS.crop_size < FLAGS.train_image_width and + FLAGS.crop_size < FLAGS.train_image_height) and FLAGS.mode == 'train': + print('[Config] Use random crop') + inputs.set_shape([FLAGS.train_image_height, FLAGS.train_image_width, 3]) + targets.set_shape([FLAGS.train_image_height, FLAGS.train_image_width, 3]) + offset_w = tf.cast(tf.floor(tf.random_uniform([], 0, FLAGS.train_image_width - FLAGS.crop_size)), dtype=tf.int32) + offset_h = tf.cast(tf.floor(tf.random_uniform([], 0, FLAGS.train_image_height - FLAGS.crop_size)), dtype=tf.int32) + + inputs = tf.image.crop_to_bounding_box(inputs, offset_h, offset_w, FLAGS.crop_size, FLAGS.crop_size) + targets = tf.image.crop_to_bounding_box(targets, offset_h, offset_w, FLAGS.crop_size, FLAGS.crop_size) + # Do not perform crop + else: + inputs = tf.identity(inputs) + targets = tf.identity(targets) + + with tf.variable_scope('random_flip'): + # Check for random flip: + if (FLAGS.flip is True) and (FLAGS.mode == 'train'): + print('[Config] Use random flip') + # Produce the decision of random flip + decision = tf.random_uniform([], 0, 1, dtype=tf.float32) + + input_images = random_flip(inputs, decision) + target_images = random_flip(targets, decision) + else: + input_images = tf.identity(inputs) + target_images = tf.identity(targets) + + if FLAGS.mode == 'train': + paths_batch, inputs_batch, targets_batch = tf.train.shuffle_batch([paths, input_images, target_images], + batch_size=FLAGS.batch_size, capacity=FLAGS.image_queue_capacity, + min_after_dequeue=512, num_threads=20) + else: + paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], + batch_size=FLAGS.batch_size, num_threads=20, allow_smaller_final_batch=True) + + steps_per_epoch = int(math.ceil(len(image_list) / FLAGS.batch_size)) + + return Data( + paths=paths_batch, + inputs=inputs_batch, + targets=targets_batch, + image_count=len(image_list), + steps_per_epoch=steps_per_epoch + ) + + +def data_loader2(FLAGS): + with tf.device('/cpu:0'): + # Define the returned data 
batches + Data = collections.namedtuple('Data', 'paths_LR, paths_HR, inputs, targets, image_count, steps_per_epoch') + + #Check the input directory + if (FLAGS.input_dir_LR == 'None') or (FLAGS.input_dir_HR == 'None'): + raise ValueError('Input directory is not provided') + + if (not os.path.exists(FLAGS.input_dir_LR)) or (not os.path.exists(FLAGS.input_dir_HR)): + raise ValueError('Input directory not found') + + image_list_LR = os.listdir(FLAGS.input_dir_LR) + image_list_LR = [_ for _ in image_list_LR if _.endswith('.png')] + if len(image_list_LR)==0: + raise Exception('No png files in the input directory') + + image_list_LR_temp = sorted(image_list_LR) + image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp] + image_list_HR = [os.path.join(FLAGS.input_dir_HR, _) for _ in image_list_LR_temp] + + image_list_LR_tensor = tf.convert_to_tensor(image_list_LR, dtype=tf.string) + image_list_HR_tensor = tf.convert_to_tensor(image_list_HR, dtype=tf.string) + + with tf.variable_scope('load_image'): + # define the image list queue + # image_list_LR_queue = tf.train.string_input_producer(image_list_LR, shuffle=False, capacity=FLAGS.name_queue_capacity) + # image_list_HR_queue = tf.train.string_input_producer(image_list_HR, shuffle=False, capacity=FLAGS.name_queue_capacity) + #print('[Queue] image list queue use shuffle: %s'%(FLAGS.mode == 'Train')) + output = tf.train.slice_input_producer([image_list_LR_tensor, image_list_HR_tensor], + shuffle=False, capacity=FLAGS.name_queue_capacity) + + # Reading and decode the images + reader = tf.WholeFileReader(name='image_reader') + image_LR = tf.read_file(output[0]) + image_HR = tf.read_file(output[1]) + input_image_LR = tf.image.decode_png(image_LR, channels=3) + input_image_HR = tf.image.decode_png(image_HR, channels=3) + input_image_LR = tf.image.convert_image_dtype(input_image_LR, dtype=tf.float32) + input_image_HR = tf.image.convert_image_dtype(input_image_HR, dtype=tf.float32) + + assertion = tf.assert_equal(tf.shape(input_image_LR)[2], 3, message="image does not have 3 channels") + with tf.control_dependencies([assertion]): + input_image_LR = tf.identity(input_image_LR) + input_image_HR = tf.identity(input_image_HR) + + # Normalize the low resolution image to [0, 1], high resolution to [-1, 1] + a_image = preprocessLR(input_image_LR) + b_image = preprocess(input_image_HR) + + inputs, targets = [a_image, b_image] + + # The data augmentation part + with tf.name_scope('data_preprocessing'): + with tf.name_scope('random_crop'): + # Check whether perform crop + if (FLAGS.random_crop is True) and FLAGS.mode == 'train': + print('[Config] Use random crop') + # Set the shape of the input image. 
the target will have 4X size + input_size = tf.shape(inputs) + target_size = tf.shape(targets) + offset_w = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[1], tf.float32) - FLAGS.crop_size)), + dtype=tf.int32) + offset_h = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[0], tf.float32) - FLAGS.crop_size)), + dtype=tf.int32) + + if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet': + inputs = tf.image.crop_to_bounding_box(inputs, offset_h, offset_w, FLAGS.crop_size, + FLAGS.crop_size) + targets = tf.image.crop_to_bounding_box(targets, offset_h*4, offset_w*4, FLAGS.crop_size*4, + FLAGS.crop_size*4) + elif FLAGS.task == 'denoise': + inputs = tf.image.crop_to_bounding_box(inputs, offset_h, offset_w, FLAGS.crop_size, + FLAGS.crop_size) + targets = tf.image.crop_to_bounding_box(targets, offset_h, offset_w, + FLAGS.crop_size, FLAGS.crop_size) + # Do not perform crop + else: + inputs = tf.identity(inputs) + targets = tf.identity(targets) + + with tf.variable_scope('random_flip'): + # Check for random flip: + if (FLAGS.flip is True) and (FLAGS.mode == 'train'): + print('[Config] Use random flip') + # Produce the decision of random flip + decision = tf.random_uniform([], 0, 1, dtype=tf.float32) + + input_images = random_flip(inputs, decision) + target_images = random_flip(targets, decision) + else: + input_images = tf.identity(inputs) + target_images = tf.identity(targets) + + if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet': + input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3]) + target_images.set_shape([FLAGS.crop_size*4, FLAGS.crop_size*4, 3]) + elif FLAGS.task == 'denoise': + input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3]) + target_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3]) + + if FLAGS.mode == 'train': + paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.shuffle_batch([output[0], output[1], input_images, target_images], + batch_size=FLAGS.batch_size, capacity=FLAGS.image_queue_capacity+4*FLAGS.batch_size, + min_after_dequeue=FLAGS.image_queue_capacity, num_threads=FLAGS.queue_thread) + else: + paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.batch([output[0], output[1], input_images, target_images], + batch_size=FLAGS.batch_size, num_threads=FLAGS.queue_thread, allow_smaller_final_batch=True) + + steps_per_epoch = int(math.ceil(len(image_list_LR) / FLAGS.batch_size)) + if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet': + inputs_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3]) + targets_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size*4, FLAGS.crop_size*4, 3]) + elif FLAGS.task == 'denoise': + inputs_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3]) + targets_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3]) + return Data( + paths_LR=paths_LR_batch, + paths_HR=paths_HR_batch, + inputs=inputs_batch, + targets=targets_batch, + image_count=len(image_list_LR), + steps_per_epoch=steps_per_epoch + ) + + +# The test data loader. 
Allow input image with different size +def test_data_loader(FLAGS): + # Get the image name list + if (FLAGS.input_dir_LR == 'None') or (FLAGS.input_dir_HR == 'None'): + raise ValueError('Input directory is not provided') + + if (not os.path.exists(FLAGS.input_dir_LR)) or (not os.path.exists(FLAGS.input_dir_HR)): + raise ValueError('Input directory not found') + + image_list_LR_temp = os.listdir(FLAGS.input_dir_LR) + image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'png'] + image_list_HR = [os.path.join(FLAGS.input_dir_HR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'png'] + + # Read in and preprocess the images + def preprocess_test(name, mode): + im = sic.imread(name).astype(np.float32) + # check grayscale image + if im.shape[-1] != 3: + h, w = im.shape + temp = np.empty((h, w, 3), dtype=np.uint8) + temp[:, :, :] = im[:, :, np.newaxis] + im = temp.copy() + if mode == 'LR': + im = im / np.max(im) + elif mode == 'HR': + im = im / np.max(im) + im = im * 2 - 1 + + return im + + image_LR = [preprocess_test(_, 'LR') for _ in image_list_LR] + image_HR = [preprocess_test(_, 'HR') for _ in image_list_HR] + + # Push path and image into a list + Data = collections.namedtuple('Data', 'paths_LR, paths_HR, inputs, targets') + + return Data( + paths_LR = image_list_LR, + paths_HR = image_list_HR, + inputs = image_LR, + targets = image_HR + ) + + +# Definition of the generator +def generator(gen_inputs, gen_output_channels, reuse=False, FLAGS=None): + # Check the flag + if FLAGS is None: + raise ValueError('No FLAGS is provided for generator') + + # The Bx residual blocks + def residual_block(inputs, output_channel, stride, scope): + with tf.variable_scope(scope): + net = conv2(inputs, 3, output_channel, stride, use_bias=False, scope='conv_1') + net = batchnorm(net, FLAGS.is_training) + net = prelu_tf(net) + net = conv2(net, 3, output_channel, stride, use_bias=False, scope='conv_2') + net = batchnorm(net, FLAGS.is_training) + net = net + inputs + + return net + + + with tf.variable_scope('generator_unit', reuse=reuse): + # The input layer + with tf.variable_scope('input_stage'): + net = conv2(gen_inputs, 9, 64, 1, scope='conv') + net = prelu_tf(net) + + stage1_output = net + + # The residual block parts + for i in range(1, FLAGS.num_resblock+1 , 1): + name_scope = 'resblock_%d'%(i) + net = residual_block(net, 64, 1, name_scope) + + with tf.variable_scope('resblock_output'): + net = conv2(net, 3, 64, 1, use_bias=False, scope='conv') + net = batchnorm(net, FLAGS.is_training) + + net = net + stage1_output + + with tf.variable_scope('subpixelconv_stage1'): + net = conv2(net, 3, 256, 1, scope='conv') + net = pixelShuffler(net, scale=2) + net = prelu_tf(net) + + with tf.variable_scope('subpixelconv_stage2'): + net = conv2(net, 3, 256, 1, scope='conv') + net = pixelShuffler(net, scale=2) + net = prelu_tf(net) + + with tf.variable_scope('output_stage'): + net = conv2(net, 9, gen_output_channels, 1, scope='conv') + + return net + + +# Define the generator for the denoise task +def generator_denoise(gen_inputs, gen_output_channels, reuse=False, FLAGS=None): + # Check the flag + if FLAGS is None: + raise ValueError('No FLAGS is provided for generator') + + # The Bx residual blocks + def residual_block(inputs, output_channel, stride, scope): + with tf.variable_scope(scope): + net = conv2(inputs, 3, output_channel, stride, use_bias=False, scope='conv_1') + net = batchnorm(net, FLAGS.is_training) + net = prelu_tf(net) + net = conv2(net, 3, 
output_channel, stride, use_bias=False, scope='conv_2') + net = batchnorm(net, FLAGS.is_training) + net = net + inputs + + return net + + # [optional] Put network on different GPU + with tf.device('/gpu:0'): + with tf.variable_scope('generator_unit', reuse=reuse): + # The input layer + with tf.variable_scope('input_stage'): + net = conv2(gen_inputs, 9, 64, 1, scope='conv') + net = prelu_tf(net) + + stage1_output = net + + # The residual block parts + for i in range(1, FLAGS.num_resblock+1 , 1): + name_scope = 'resblock_%d'%(i) + net = residual_block(net, 64, 1, name_scope) + + with tf.variable_scope('resblock_output'): + net = conv2(net, 3, 64, 1, use_bias=False, scope='conv') + net = batchnorm(net, FLAGS.is_training) + + net = net + stage1_output + + with tf.device('/gpu:0'): + with tf.variable_scope('refine_stage1'): + net = conv2(net, 3, 256, 1, scope='conv') + net = batchnorm(net, FLAGS.is_training) + net = prelu_tf(net) + + with tf.variable_scope('refine_stage2'): + net = conv2(net, 3, 256, 1, scope='conv') + net = batchnorm(net, FLAGS.is_training) + net = prelu_tf(net) + + with tf.variable_scope('output_stage'): + net = conv2(net, 9, gen_output_channels, 1, scope='conv') + print(net.get_shape(), 'output_stage') + + return net + + +# Definition of the discriminator +def discriminator(dis_inputs, FLAGS=None): + if FLAGS is None: + raise ValueError('No FLAGS is provided for generator') + + # Define the discriminator block + def discriminator_block(inputs, output_channel, kernel_size, stride, scope): + with tf.variable_scope(scope): + net = conv2(inputs, kernel_size, output_channel, stride, use_bias=False, scope='conv1') + net = batchnorm(net, FLAGS.is_training) + net = lrelu(net, 0.2) + + return net + + with tf.device('/gpu:0'): + with tf.variable_scope('discriminator_unit'): + # The input layer + with tf.variable_scope('input_stage'): + net = conv2(dis_inputs, 3, 64, 1, scope='conv') + net = lrelu(net, 0.2) + + # The discriminator block part + # block 1 + net = discriminator_block(net, 64, 3, 2, 'disblock_1') + + # block 2 + net = discriminator_block(net, 128, 3, 1, 'disblock_2') + + # block 3 + net = discriminator_block(net, 128, 3, 2, 'disblock_3') + + # block 4 + net = discriminator_block(net, 256, 3, 1, 'disblock_4') + + # block 5 + net = discriminator_block(net, 256, 3, 2, 'disblock_5') + + # block 6 + net = discriminator_block(net, 512, 3, 1, 'disblock_6') + + # block_7 + net = discriminator_block(net, 512, 3, 2, 'disblock_7') + + # The dense layer 1 + with tf.variable_scope('dense_layer_1'): + net = slim.flatten(net) + net = denselayer(net, 1024) + net = lrelu(net, 0.2) + + # The dense layer 2 + with tf.variable_scope('dense_layer_2'): + net = denselayer(net, 1) + net = tf.nn.sigmoid(net) + + return net + + +# Define the feature extractor +def VGG19_keras(type): + # Define the feature to extract according to the type of perceptual + if type == 'VGG54': + target_layer = 'block5_conv4' + elif type == 'VGG22': + target_layer = 'block2_conv2' + else: + raise NotImplementedError('Unknown perceptual type') + # Define the base model + base_model = VGG19(include_top=False, weights='imagenet') + extractor = Model(inputs=base_model.inputs, outputs=base_model.get_layer(target_layer).output) + + return extractor + + +def VGG19_slim(input, type, reuse, scope): + # Define the feature to extract according to the type of perceptual + if type == 'VGG54': + target_layer = scope + 'vgg_19/conv5/conv5_4' + elif type == 'VGG22': + target_layer = scope + 'vgg_19/conv2/conv2_2' + else: + raise 
NotImplementedError('Unknown perceptual type') + _, output = vgg_19(input, is_training=False, reuse=reuse) + output = output[target_layer] + + return output + + +# Define the whole network architecture +def SRGAN(inputs, targets, FLAGS): + # Define the container of the parameter + Network = collections.namedtuple('Network', 'discrim_real_output, discrim_fake_output, discrim_loss, \ + discrim_grads_and_vars, adversarial_loss, content_loss, gen_grads_and_vars, gen_output, train, global_step, \ + learning_rate') + + # Build the generator part + with tf.variable_scope('generator'): + output_channel = targets.get_shape().as_list()[-1] + gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS) + gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size*4, FLAGS.crop_size*4, 3]) + + # Build the fake discriminator + with tf.name_scope('fake_discriminator'): + with tf.variable_scope('discriminator', reuse=False): + discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS) + + # Build the real discriminator + with tf.name_scope('real_discriminator'): + with tf.variable_scope('discriminator', reuse=True): + discrim_real_output = discriminator(targets, FLAGS=FLAGS) + + # Use the VGG54 feature + if FLAGS.perceptual_mode == 'VGG54': + with tf.name_scope('vgg19_1') as scope: + extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope) + with tf.name_scope('vgg19_2') as scope: + extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope) + + elif FLAGS.perceptual_mode == 'VGG22': + extractor = VGG19_keras(FLAGS.perceptual_mode) + extracted_feature_gen = extractor.call(gen_output) + extracted_feature_target = extractor.call(targets) + + elif FLAGS.perceptual_mode == 'MSE': + extracted_feature_gen = gen_output + extracted_feature_target = targets + + else: + raise NotImplementedError('Unknown perceptual type') + + # Calculating the generator loss + with tf.variable_scope('generator_loss'): + # Content loss + with tf.variable_scope('content_loss'): + # Compute the euclidean distance between the two features + # check=tf.equal(extracted_feature_gen, extracted_feature_target) + diff = extracted_feature_gen - extracted_feature_target + if FLAGS.perceptual_mode == 'MSE': + content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) + else: + content_loss = FLAGS.vgg_scaling*tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) + + with tf.variable_scope('adversarial_loss'): + adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS)) + + gen_loss = content_loss + (FLAGS.ratio)*adversarial_loss + print(adversarial_loss.get_shape()) + print(content_loss.get_shape()) + + # Calculating the discriminator loss + with tf.variable_scope('discriminator_loss'): + discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS) + discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS) + + discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss)) + + # Define the learning rate and global step + with tf.variable_scope('get_learning_rate_and_global_step'): + global_step = tf.contrib.framework.get_or_create_global_step() + learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, staircase=FLAGS.stair) + incr_global_step = tf.assign(global_step, global_step + 1) + + with tf.variable_scope('dicriminator_train'): + discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') + discrim_optimizer = 
tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta) + discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars) + discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars) + + with tf.variable_scope('generator_train'): + # Need to wait discriminator to perform train step + with tf.control_dependencies([discrim_train]+ tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') + gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta) + gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars) + gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars) + + #[ToDo] If we do not use moving average on loss?? + exp_averager = tf.train.ExponentialMovingAverage(decay=0.99) + update_loss = exp_averager.apply([discrim_loss, adversarial_loss, content_loss]) + + return Network( + discrim_real_output = discrim_real_output, + discrim_fake_output = discrim_fake_output, + discrim_loss = exp_averager.average(discrim_loss), + discrim_grads_and_vars = discrim_grads_and_vars, + adversarial_loss = exp_averager.average(adversarial_loss), + content_loss = exp_averager.average(content_loss), + gen_grads_and_vars = gen_grads_and_vars, + gen_output = gen_output, + train = tf.group(update_loss, incr_global_step, gen_train), + global_step = global_step, + learning_rate = learning_rate + ) + + +def SRResnet(inputs, targets, FLAGS): + # Define the container of the parameter + Network = collections.namedtuple('Network', 'content_loss, gen_grads_and_vars, gen_output, train, global_step, \ + learning_rate') + + # Build the generator part + with tf.variable_scope('generator'): + output_channel = targets.get_shape().as_list()[-1] + gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS) + gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3]) + + # Use the VGG54 feature + if FLAGS.perceptual_mode == 'VGG54': + with tf.name_scope('vgg19_1') as scope: + extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope) + with tf.name_scope('vgg19_2') as scope: + extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope) + + elif FLAGS.perceptual_mode == 'VGG22': + with tf.name_scope('vgg19_1') as scope: + extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope) + with tf.name_scope('vgg19_2') as scope: + extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope) + + elif FLAGS.perceptual_mode == 'MSE': + extracted_feature_gen = gen_output + extracted_feature_target = targets + + else: + raise NotImplementedError('Unknown perceptual type') + + # Calculating the generator loss + with tf.variable_scope('generator_loss'): + # Content loss + with tf.variable_scope('content_loss'): + # Compute the euclidean distance between the two features + # check=tf.equal(extracted_feature_gen, extracted_feature_target) + diff = extracted_feature_gen - extracted_feature_target + if FLAGS.perceptual_mode == 'MSE': + content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) + else: + content_loss = FLAGS.vgg_scaling * tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) + + gen_loss = content_loss + + # Define the learning rate and global step + with tf.variable_scope('get_learning_rate_and_global_step'): + global_step = tf.contrib.framework.get_or_create_global_step() + learning_rate = 
tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, + staircase=FLAGS.stair) + incr_global_step = tf.assign(global_step, global_step + 1) + + with tf.variable_scope('generator_train'): + # Need to wait discriminator to perform train step + with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') + gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta) + gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars) + gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars) + + # [ToDo] If we do not use moving average on loss?? + exp_averager = tf.train.ExponentialMovingAverage(decay=0.99) + update_loss = exp_averager.apply([content_loss]) + + return Network( + content_loss=exp_averager.average(content_loss), + gen_grads_and_vars=gen_grads_and_vars, + gen_output=gen_output, + train=tf.group(update_loss, incr_global_step, gen_train), + global_step=global_step, + learning_rate=learning_rate + ) + + +# Use the denseNet version of the generator +def SRResnet_dense(inputs, targets, FLAGS): + # Define the container of the parameter + Network = collections.namedtuple('Network', 'content_loss, gen_grads_and_vars, gen_output, train, global_step, \ + learning_rate') + + # Build the generator part + with tf.variable_scope('generator'): + output_channel = targets.get_shape().as_list()[-1] + gen_output = generatorDense(inputs, output_channel, reuse=False, FLAGS=FLAGS) + gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3]) + + # Use the VGG54 feature + if FLAGS.perceptual_mode == 'VGG54': + with tf.name_scope('vgg19_1') as scope: + extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope) + with tf.name_scope('vgg19_2') as scope: + extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope) + + elif FLAGS.perceptual_mode == 'VGG22': + with tf.name_scope('vgg19_1') as scope: + extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope) + with tf.name_scope('vgg19_2') as scope: + extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope) + + elif FLAGS.perceptual_mode == 'MSE': + extracted_feature_gen = gen_output + extracted_feature_target = targets + + else: + raise NotImplementedError('Unknown perceptual type') + + # Calculating the generator loss + with tf.variable_scope('generator_loss'): + # Content loss + with tf.variable_scope('content_loss'): + # Compute the euclidean distance between the two features + # check=tf.equal(extracted_feature_gen, extracted_feature_target) + diff = extracted_feature_gen - extracted_feature_target + if FLAGS.perceptual_mode == 'MSE': + content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) + else: + content_loss = FLAGS.vgg_scaling * tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) + + gen_loss = content_loss + + # Define the learning rate and global step + with tf.variable_scope('get_learning_rate_and_global_step'): + global_step = tf.contrib.framework.get_or_create_global_step() + learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, + staircase=FLAGS.stair) + incr_global_step = tf.assign(global_step, global_step + 1) + + with tf.variable_scope('generator_train'): + # Need to wait discriminator to perform train step + with 
tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') + gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta) + gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars) + gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars) + + # [ToDo] If we do not use moving average on loss?? + # exp_averager = tf.train.ExponentialMovingAverage(decay=0.99) + # update_loss = exp_averager.apply([content_loss]) + + return Network( + content_loss=content_loss, + gen_grads_and_vars=gen_grads_and_vars, + gen_output=gen_output, + train=tf.group(content_loss, incr_global_step, gen_train), + global_step=global_step, + learning_rate=learning_rate + ) + + +# Define the whole network architecture +def network_denoise(inputs, targets, FLAGS): + # Define the container of the parameter + Network = collections.namedtuple('Network', 'discrim_real_output, discrim_fake_output, discrim_loss, \ + discrim_grads_and_vars, adversarial_loss, content_loss, gen_grads_and_vars, gen_output, train, global_step, \ + learning_rate') + + # Build the generator part + with tf.variable_scope('generator'): + output_channel = targets.get_shape().as_list()[-1] + gen_output = generator_denoise(inputs, output_channel, reuse=False, FLAGS=FLAGS) + gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3]) + + # Build the fake discriminator + with tf.name_scope('fake_discriminator'): + with tf.variable_scope('discriminator', reuse=False): + discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS) + + # Build the real discriminator + with tf.name_scope('real_discriminator'): + with tf.variable_scope('discriminator', reuse=True): + discrim_real_output = discriminator(targets, FLAGS=FLAGS) + + # Use the VGG54 feature + if FLAGS.perceptual_mode == 'VGG54': + with tf.name_scope('vgg19_1') as scope: + extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope) + with tf.name_scope('vgg19_2') as scope: + extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope) + + elif FLAGS.perceptual_mode == 'VGG22': + with tf.name_scope('vgg19_1') as scope: + extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope) + with tf.name_scope('vgg19_2') as scope: + extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope) + + elif FLAGS.perceptual_mode == 'MSE': + extracted_feature_gen = gen_output + extracted_feature_target = targets + + else: + raise NotImplementedError('Unknown perceptual type') + + # Calculating the generator loss + with tf.variable_scope('generator_loss'): + # Content loss + with tf.variable_scope('content_loss'): + # Compute the euclidean distance between the two features + if (FLAGS.perceptual_mode == 'VGG54') or (FLAGS.perceptual_mode == 'VGG22'): + diff = extracted_feature_gen - extracted_feature_target + content_loss = FLAGS.vgg_scaling * tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) + elif FLAGS.perceptual_mode == 'MSE': + diff = targets - gen_output + content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3])) + else: + raise NotImplementedError('Unknown perceptual type') + + with tf.variable_scope('adversarial_loss'): + adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS)) + + gen_loss = content_loss + (FLAGS.ratio)*adversarial_loss + + # Calculating the discriminator loss + with 
tf.variable_scope('discriminator_loss'): + discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS) + discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS) + + discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss)) + + # Define the learning rate and global step + with tf.variable_scope('get_learning_rate_and_global_step'): + global_step = tf.contrib.framework.get_or_create_global_step() + learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, staircase=FLAGS.stair) + incr_global_step = tf.assign(global_step, global_step + 1) + + with tf.variable_scope('dicriminator_train'): + discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') + discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta) + discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars) + discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars) + + with tf.variable_scope('generator_train'): + # Need to wait discriminator to perform train step + with tf.control_dependencies([discrim_train]+ tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') + gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta) + gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars) + gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars) + + #[ToDo] If we do not use moving average on loss?? + exp_averager = tf.train.ExponentialMovingAverage(decay=0.99) + update_loss = exp_averager.apply([discrim_loss, adversarial_loss, content_loss]) + + return Network( + discrim_real_output = discrim_real_output, + discrim_fake_output = discrim_fake_output, + discrim_loss = exp_averager.average(discrim_loss), + discrim_grads_and_vars = discrim_grads_and_vars, + adversarial_loss = exp_averager.average(adversarial_loss), + content_loss = exp_averager.average(content_loss), + gen_grads_and_vars = gen_grads_and_vars, + gen_output = gen_output, + train = tf.group(update_loss, incr_global_step, gen_train), + global_step = global_step, + learning_rate = learning_rate + ) + + +def save_images(fetches, FLAGS, step=None): + image_dir = os.path.join(FLAGS.output_dir, "images") + if not os.path.exists(image_dir): + os.makedirs(image_dir) + + filesets = [] + in_path = fetches['path_LR'] + name, _ = os.path.splitext(os.path.basename(str(in_path))) + fileset = {"name": name, "step": step} + for kind in ["inputs", "outputs", "targets"]: + filename = name + "-" + kind + ".png" + if step is not None: + filename = "%08d-%s" % (step, filename) + fileset[kind] = filename + out_path = os.path.join(image_dir, filename) + contents = fetches[kind][0] + with open(out_path, "wb") as f: + f.write(contents) + filesets.append(fileset) + return filesets + + + + + + + + + + + diff --git a/lib/model_dense.py b/lib/model_dense.py new file mode 100644 index 0000000..38c51c0 --- /dev/null +++ b/lib/model_dense.py @@ -0,0 +1,95 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from lib.ops import * +import collections +import os +import math +import scipy.misc as sic +import numpy as np + + +# The dense layer +def denseConvlayer(layer_inputs, bottleneck_scale, growth_rate, is_training): + # Build the bottleneck operation + net = layer_inputs + net_temp = tf.identity(net) + net = batchnorm(net, is_training) + net 
= prelu_tf(net, name='Prelu_1') + net = conv2(net, kernel=1, output_channel=bottleneck_scale*growth_rate, stride=1, use_bias=False, scope='conv1x1') + net = batchnorm(net, is_training) + net = prelu_tf(net, name='Prelu_2') + net = conv2(net, kernel=3, output_channel=growth_rate, stride=1, use_bias=False, scope='conv3x3') + + # Concatenate the processed feature to the feature + net = tf.concat([net_temp, net], axis=3) + + return net + + +# The transition layer +def transitionLayer(layer_inputs, output_channel, is_training): + net = layer_inputs + net = batchnorm(net, is_training) + net = prelu_tf(net) + net = conv2(net, 1, output_channel, stride=1, use_bias=False, scope='conv1x1') + + return net + + +# The dense block +def denseBlock(block_inputs, num_layers, bottleneck_scale, growth_rate, FLAGS): + # Build each layer consecutively + net = block_inputs + for i in range(num_layers): + with tf.variable_scope('dense_conv_layer%d'%(i+1)): + net = denseConvlayer(net, bottleneck_scale, growth_rate, FLAGS.is_training) + + return net + + +# Here we define the dense block version generator +def generatorDense(gen_inputs, gen_output_channels, reuse=False, FLAGS=None): + # Check the flag + if FLAGS is None: + raise ValueError('No FLAGS is provided for generator') + + # The main netowrk + with tf.variable_scope('generator_unit', reuse=reuse): + # The input stage + with tf.variable_scope('input_stage'): + net = conv2(gen_inputs, 9, 64, 1, scope='conv') + net = prelu_tf(net) + + # The dense block part + # Define the denseblock configuration + layer_per_block = 16 + bottleneck_scale = 4 + growth_rate = 12 + transition_output_channel = 128 + with tf.variable_scope('denseBlock_1'): + net = denseBlock(net, layer_per_block, bottleneck_scale, growth_rate, FLAGS) + + with tf.variable_scope('transition_layer_1'): + net = transitionLayer(net, transition_output_channel, FLAGS.is_training) + + with tf.variable_scope('subpixelconv_stage1'): + net = conv2(net, 3, 256, 1, scope='conv') + net = pixelShuffler(net, scale=2) + net = prelu_tf(net) + + with tf.variable_scope('subpixelconv_stage2'): + net = conv2(net, 3, 256, 1, scope='conv') + net = pixelShuffler(net, scale=2) + net = prelu_tf(net) + + with tf.variable_scope('output_stage'): + net = conv2(net, 9, gen_output_channels, 1, scope='conv') + + return net + + + + diff --git a/lib/ops.py b/lib/ops.py new file mode 100644 index 0000000..a0ee165 --- /dev/null +++ b/lib/ops.py @@ -0,0 +1,231 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import tensorflow as tf +import tensorflow.contrib.slim as slim +import tensorflow.contrib.keras as keras +import pdb + + +def preprocess(image): + with tf.name_scope("preprocess"): + # [0, 1] => [-1, 1] + return image * 2 - 1 + + +def deprocess(image): + with tf.name_scope("deprocess"): + # [-1, 1] => [0, 1] + return (image + 1) / 2 + + +def preprocessLR(image): + with tf.name_scope("preprocessLR"): + return tf.identity(image) + + +def deprocessLR(image): + with tf.name_scope("deprocessLR"): + return tf.identity(image) + + +# Define the convolution building block +def conv2(batch_input, kernel=3, output_channel=64, stride=1, use_bias=True, scope='conv'): + # kernel: An integer specifying the width and height of the 2D convolution window + with tf.variable_scope(scope): + if use_bias: + return slim.conv2d(batch_input, output_channel, [kernel, kernel], stride, 'SAME', data_format='NHWC', + activation_fn=None, 
                               weights_initializer=tf.contrib.layers.xavier_initializer())
+        else:
+            return slim.conv2d(batch_input, output_channel, [kernel, kernel], stride, 'SAME', data_format='NHWC',
+                               activation_fn=None, weights_initializer=tf.contrib.layers.xavier_initializer(),
+                               biases_initializer=None)
+
+
+def conv2_NCHW(batch_input, kernel=3, output_channel=64, stride=1, use_bias=True, scope='conv_NCHW'):
+    # Use NCHW to speed up the inference
+    # kernel: an integer specifying the width and height of the 2D convolution window
+    with tf.variable_scope(scope):
+        if use_bias:
+            return slim.conv2d(batch_input, output_channel, [kernel, kernel], stride, 'SAME', data_format='NCHW',
+                               activation_fn=None, weights_initializer=tf.contrib.layers.xavier_initializer())
+        else:
+            return slim.conv2d(batch_input, output_channel, [kernel, kernel], stride, 'SAME', data_format='NCHW',
+                               activation_fn=None, weights_initializer=tf.contrib.layers.xavier_initializer(),
+                               biases_initializer=None)
+
+# [ToDo]: Write code to test this function and check the reuse of this model
+# Define our PReLU (Keras version)
+def prelu(inputs, input_shape):
+    prelu_unit = keras.layers.PReLU()
+    prelu_unit.build(input_shape)
+    return prelu_unit.call(inputs)
+
+
+# Define our TensorFlow version of PReLU: relu(x) + alpha * min(x, 0) with a learnable per-channel alpha
+def prelu_tf(inputs, name='Prelu'):
+    with tf.variable_scope(name):
+        alphas = tf.get_variable('alpha', inputs.get_shape()[-1], initializer=tf.zeros_initializer(), dtype=tf.float32)
+        pos = tf.nn.relu(inputs)
+        neg = alphas * (inputs - abs(inputs)) * 0.5
+
+    return pos + neg
+
+# Define our leaky ReLU
+def lrelu(inputs, alpha):
+    return keras.layers.LeakyReLU(alpha=alpha).call(inputs)
+
+
+def batchnorm(inputs, is_training):
+    return slim.batch_norm(inputs, decay=0.9, epsilon=0.001, updates_collections=tf.GraphKeys.UPDATE_OPS,
+                           scale=False, fused=True, is_training=is_training)
+
+
+# Our dense layer
+def denselayer(inputs, output_size):
+    output = tf.layers.dense(inputs, output_size, activation=None, kernel_initializer=tf.contrib.layers.xavier_initializer())
+    return output
+
+
+# The implementation of PixelShuffler (sub-pixel convolution): rearrange a [B, H, W, C*scale^2] tensor into [B, H*scale, W*scale, C]
+def pixelShuffler(inputs, scale=2):
+    size = tf.shape(inputs)
+    batch_size = size[0]
+    h = size[1]
+    w = size[2]
+    c = inputs.get_shape().as_list()[-1]
+
+    # Get the target channel size
+    channel_target = c // (scale * scale)
+    channel_factor = c // channel_target
+
+    shape_1 = [batch_size, h, w, channel_factor // scale, channel_factor // scale]
+    shape_2 = [batch_size, h * scale, w * scale, 1]
+
+    # Reshape and transpose for periodic shuffling for each channel
+    input_split = tf.split(inputs, channel_target, axis=3)
+    output = tf.concat([phaseShift(x, scale, shape_1, shape_2) for x in input_split], axis=3)
+
+    return output
+
+
+def phaseShift(inputs, scale, shape_1, shape_2):
+    # Tackle the condition when the batch is None
+    X = tf.reshape(inputs, shape_1)
+    X = tf.transpose(X, [0, 1, 3, 2, 4])
+
+    return tf.reshape(X, shape_2)
+
+
+# The random flip operation used for loading examples
+def random_flip(input, decision):
+    f1 = tf.identity(input)
+    f2 = tf.image.flip_left_right(input)
+    output = tf.cond(tf.less(decision, 0.5), lambda: f2, lambda: f1)
+
+    return output
+
+
+# The operation used to print out the configuration
+def print_configuration_op(FLAGS):
+    print('[Configurations]:')
+    a = FLAGS.mode
+    #pdb.set_trace()
+    for name, value in FLAGS.__flags.items():
+        if type(value) == float:
+            print('\t%s: %f'%(name, value))
+        elif type(value) == int:
+            print('\t%s: %d'%(name, value))
+        elif type(value) == str:
+            print('\t%s: %s'%(name, value))
+        elif 
type(value) == bool: + print('\t%s: %s'%(name, value)) + else: + print('\t%s: %s' % (name, value)) + + print('End of configuration') + + +def compute_psnr(ref, target): + ref = tf.cast(ref, tf.float32) + target = tf.cast(target, tf.float32) + diff = tf.subtract(target, ref) + sqr = tf.multiply(diff, diff) + err = tf.reduce_sum(sqr) + v = tf.shape(diff)[0] * tf.shape(diff)[1] * tf.shape(diff)[2] * tf.shape(diff)[3] + mse = tf.div(err, tf.cast(v, tf.float32)) + c10 = tf.constant(10, tf.float32) + c255_2 = tf.multiply(tf.constant(255,tf.float32),tf.constant(255, tf.float32)) + psnr = tf.multiply(c10, tf.div(tf.log(tf.div(c255_2, mse)), tf.log(c10))) + + return psnr + + +# VGG19 component +def vgg_arg_scope(weight_decay=0.0005): + """Defines the VGG arg scope. + Args: + weight_decay: The l2 regularization coefficient. + Returns: + An arg_scope. + """ + with slim.arg_scope([slim.conv2d, slim.fully_connected], + activation_fn=tf.nn.relu, + weights_regularizer=slim.l2_regularizer(weight_decay), + biases_initializer=tf.zeros_initializer()): + with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc: + return arg_sc + + +# VGG19 net +def vgg_19(inputs, + num_classes=1000, + is_training=False, + dropout_keep_prob=0.5, + spatial_squeeze=True, + scope='vgg_19', + reuse = False, + fc_conv_padding='VALID'): + """Oxford Net VGG 19-Layers version E Example. + Note: All the fully_connected layers have been transformed to conv2d layers. + To use in classification mode, resize input to 224x224. + Args: + inputs: a tensor of size [batch_size, height, width, channels]. + num_classes: number of predicted classes. + is_training: whether or not the model is being trained. + dropout_keep_prob: the probability that activations are kept in the dropout + layers during training. + spatial_squeeze: whether or not should squeeze the spatial dimensions of the + outputs. Useful to remove unnecessary dimensions for classification. + scope: Optional scope for the variables. + fc_conv_padding: the type of padding to use for the fully connected layer + that is implemented as a convolutional layer. Use 'SAME' padding if you + are applying the network in a fully convolutional manner and want to + get a prediction map downsampled by a factor of 32 as an output. Otherwise, + the output prediction map will be (input / 32) - 6 in case of 'VALID' padding. + Returns: + the last op containing the log predictions and end_points dict. + """ + with tf.variable_scope(scope, 'vgg_19', [inputs], reuse=reuse) as sc: + end_points_collection = sc.name + '_end_points' + # Collect outputs for conv2d, fully_connected and max_pool2d. + with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], + outputs_collections=end_points_collection): + net = slim.repeat(inputs, 2, slim.conv2d, 64, 3, scope='conv1', reuse=reuse) + net = slim.max_pool2d(net, [2, 2], scope='pool1') + net = slim.repeat(net, 2, slim.conv2d, 128, 3, scope='conv2',reuse=reuse) + net = slim.max_pool2d(net, [2, 2], scope='pool2') + net = slim.repeat(net, 4, slim.conv2d, 256, 3, scope='conv3', reuse=reuse) + net = slim.max_pool2d(net, [2, 2], scope='pool3') + net = slim.repeat(net, 4, slim.conv2d, 512, 3, scope='conv4',reuse=reuse) + net = slim.max_pool2d(net, [2, 2], scope='pool4') + net = slim.repeat(net, 4, slim.conv2d, 512, 3, scope='conv5',reuse=reuse) + net = slim.max_pool2d(net, [2, 2], scope='pool5') + # Use conv2d instead of fully_connected layers. + # Convert end_points_collection into a end_point dict. 
+      end_points = slim.utils.convert_collection_to_dict(end_points_collection)
+
+      return net, end_points
+vgg_19.default_image_size = 224
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..06e3d3e
--- /dev/null
+++ b/main.py
@@ -0,0 +1,343 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+import tensorflow.contrib.keras as keras
+import os
+from lib.model import data_loader2, generator, SRGAN, test_data_loader, save_images, SRResnet, \
+    SRResnet_dense, generator_denoise
+from lib.ops import *
+import math
+import time
+import numpy as np
+# import ipdb
+
+Flags = tf.app.flags
+
+# The system parameter
+Flags.DEFINE_string('output_dir', None, 'The output directory of the checkpoint')
+Flags.DEFINE_string('summary_dir', None, 'The directory to output the summary')
+Flags.DEFINE_string('mode', 'train', 'The mode of the model: train, test, inference')
+Flags.DEFINE_string('checkpoint', None, 'The checkpoint needed to be restored')
+Flags.DEFINE_boolean('pre_trained_model', False, 'Whether to use the pretrained model')
+Flags.DEFINE_boolean('is_training', True, 'Whether we are in the training phase')
+Flags.DEFINE_string('vgg_ckpt', './vgg19/vgg_19.ckpt', 'checkpoint file for the vgg19')
+Flags.DEFINE_string('task', None, 'The task: SRGAN, SRResnet')
+Flags.DEFINE_string('generator_type', None, 'The type of the generator: "original", "denseNet".')
+# The data preparing operation
+Flags.DEFINE_integer('batch_size', 128, 'Batch size of input')
+Flags.DEFINE_string('input_dir_LR', None, 'The directory of the low resolution input data')
+Flags.DEFINE_string('input_dir_HR', None, 'The directory of the high resolution input data')
+Flags.DEFINE_integer('train_image_width', 154, 'The width of the training image (low resolution one)')
+Flags.DEFINE_integer('train_image_height', 102, 'The height of the training image (low resolution one)')
+Flags.DEFINE_boolean('flip', True, 'Whether data augmentation is applied')
+Flags.DEFINE_boolean('random_crop', True, 'Whether to perform the random crop')
+Flags.DEFINE_integer('crop_size', 24, 'The crop size of the training image')
+Flags.DEFINE_integer('name_queue_capacity', 512, 'The capacity of the filename queue')
+Flags.DEFINE_integer('image_queue_capacity', 1024, 'The capacity of the image queue')
+Flags.DEFINE_integer('queue_thread', 20, 'The number of threads for the queue')
+# Generator configuration
+Flags.DEFINE_integer('num_resblock', 16, 'How many residual blocks are in the generator')
+# The content loss parameter
+Flags.DEFINE_string('perceptual_mode', 'VGG54', 'The type of feature used in perceptual loss')
+Flags.DEFINE_float('EPS', 1e-12, 'The epsilon added to prevent NaN')
+Flags.DEFINE_float('ratio', 0.001, 'The ratio between content loss and adversarial loss')
+Flags.DEFINE_float('vgg_scaling', 0.0061, 'The scaling factor for the perceptual loss if using vgg perceptual loss')
+# The training parameters
+Flags.DEFINE_float('learning_rate', 0.0001, 'The learning rate for the network')
+Flags.DEFINE_integer('decay_step', 20000, 'The steps needed to decay the learning rate')
+Flags.DEFINE_float('decay_rate', 0.1, 'The decay rate of each decay step')
+Flags.DEFINE_boolean('stair', False, 'Whether to perform staircase decay')
+Flags.DEFINE_float('beta', 0.9, 'The beta1 parameter for the Adam optimizer')
+Flags.DEFINE_integer('max_epoch', None, 'The max epoch for the training')
+Flags.DEFINE_integer('max_iter', 1000000, 'The max iteration of the training')
+Flags.DEFINE_integer('display_freq', 20, 'The display frequency of the training process')
+Flags.DEFINE_integer('summary_freq', 100, 'The frequency of writing summary')
+Flags.DEFINE_integer('save_freq', 10000, 'The frequency of saving images')
+
+
+FLAGS = Flags.FLAGS
+
+# Print the configuration of the model
+print_configuration_op(FLAGS)
+
+# Check that the output_dir is given
+if FLAGS.output_dir is None:
+    raise ValueError('The output directory is needed')
+
+# Check the output directory to save the checkpoint
+if not os.path.exists(FLAGS.output_dir):
+    os.mkdir(FLAGS.output_dir)
+
+# Check the summary directory to save the event
+if not os.path.exists(FLAGS.summary_dir):
+    os.mkdir(FLAGS.summary_dir)
+
+# The testing mode
+if FLAGS.mode == 'test':
+    # Check the checkpoint
+    if FLAGS.checkpoint is None:
+        raise ValueError('The checkpoint is needed to perform the test.')
+
+    # At test time, no flip or crop is needed
+    if FLAGS.flip is True:
+        FLAGS.flip = False
+
+    if FLAGS.crop_size is not None:
+        FLAGS.crop_size = None
+
+    # Declare the test data reader
+    test_data = test_data_loader(FLAGS)
+
+    inputs_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='inputs_raw')
+    targets_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='targets_raw')
+    path_LR = tf.placeholder(tf.string, shape=[], name='path_LR')
+    path_HR = tf.placeholder(tf.string, shape=[], name='path_HR')
+
+    with tf.variable_scope('generator'):
+        if FLAGS.task == 'denoise':
+            gen_output = generator_denoise(inputs_raw, 3, reuse=False, FLAGS=FLAGS)
+        elif FLAGS.task == 'SRGAN':
+            gen_output = generator(inputs_raw, 3, reuse=False, FLAGS=FLAGS)
+
+    print('Finish building the network')
+
+    with tf.name_scope('convert_image'):
+        # Deprocess the images output from the model
+        inputs = deprocessLR(inputs_raw)
+        targets = deprocess(targets_raw)
+        outputs = deprocess(gen_output)
+
+        # Convert back to uint8
+        converted_inputs = tf.image.convert_image_dtype(inputs, dtype=tf.uint8, saturate=True)
+        converted_targets = tf.image.convert_image_dtype(targets, dtype=tf.uint8, saturate=True)
+        converted_outputs = tf.image.convert_image_dtype(outputs, dtype=tf.uint8, saturate=True)
+
+    with tf.name_scope('encode_image'):
+        save_fetch = {
+            "path_LR": path_LR,
+            "path_HR": path_HR,
+            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name='input_pngs'),
+            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name='output_pngs'),
+            "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name='target_pngs')
+        }
+
+    # Define the weight initializer (restores only the variables under the generator scope)
+    var_list2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
+    weight_initiallizer = tf.train.Saver(var_list2)
+
+    # Define the initialization operation
+    init_op = tf.global_variables_initializer()
+
+    config = tf.ConfigProto()
+    config.gpu_options.allow_growth = True
+    with tf.Session(config=config) as sess:
+        # Load the pretrained model
+        print('Loading weights from the pre-trained model')
+        weight_initiallizer.restore(sess, FLAGS.checkpoint)
+
+        max_iter = len(test_data.inputs)
+        print('Evaluation starts!!')
+        for i in range(max_iter):
+            input_im = np.array([test_data.inputs[i]]).astype(np.float32)
+            target_im = np.array([test_data.targets[i]]).astype(np.float32)
+            path_lr = test_data.paths_LR[i]
+            path_hr = test_data.paths_HR[i]
+            results = sess.run(save_fetch, feed_dict={inputs_raw: input_im, targets_raw: target_im,
+                                                      path_LR: path_lr, path_HR: 
path_hr}) + filesets = save_images(results, FLAGS) + for i, f in enumerate(filesets): + print('evaluate image', f['name']) + +# The training mode +elif FLAGS.mode == 'train': + # Load data for training and testing + # ToDo Add online noise adding and downscaling + data = data_loader2(FLAGS) + print('Data count = %d' % (data.image_count)) + + # Connect to the network + if FLAGS.task == 'SRGAN': + if FLAGS.generator_type == 'original': + Net = SRGAN(data.inputs, data.targets, FLAGS) + # elif FLAGS.generator_type == 'denseNet': + # Net = SRResnet_dense(data.inputs, data.targets, FLAGS) + else: + raise ValueError('Unknown generator type') + elif FLAGS.task =='SRResnet': + if FLAGS.generator_type == 'original': + Net = SRResnet(data.inputs, data.targets, FLAGS) + elif FLAGS.generator_type == 'denseNet': + Net = SRResnet_dense(data.inputs, data.targets, FLAGS) + else: + raise ValueError('Unknown generator type') + + print('Finish building the network!!') + + # Convert the images output from the network + with tf.name_scope('convert_image'): + # Deprocess the images outputed from the model + inputs = deprocessLR(data.inputs) + targets = deprocess(data.targets) + outputs = deprocess(Net.gen_output) + + # Convert back to uint8 + converted_inputs = tf.image.convert_image_dtype(inputs, dtype=tf.uint8, saturate=True) + converted_targets = tf.image.convert_image_dtype(targets, dtype=tf.uint8, saturate=True) + converted_outputs = tf.image.convert_image_dtype(outputs, dtype=tf.uint8, saturate=True) + + # Compute PSNR + with tf.name_scope("compute_psnr"): + psnr = compute_psnr(converted_targets, converted_outputs) + + # Add image summaries + with tf.name_scope('inputs_summary'): + tf.summary.image('input_summary', converted_inputs) + + with tf.name_scope('targets_summary'): + tf.summary.image('target_summary', converted_targets) + + with tf.name_scope('outputs_summary'): + tf.summary.image('outputs_summary', converted_outputs) + + # Add scalar summary + if FLAGS.task == 'SRGAN': + tf.summary.scalar('discriminator_loss', Net.discrim_loss) + tf.summary.scalar('adversarial_loss', Net.adversarial_loss) + tf.summary.scalar('content_loss', Net.content_loss) + tf.summary.scalar('generator_loss', Net.content_loss + FLAGS.ratio*Net.adversarial_loss) + tf.summary.scalar('PSNR', psnr) + tf.summary.scalar('learning_rate', Net.learning_rate) + elif FLAGS.task == 'SRResnet': + tf.summary.scalar('content_loss', Net.content_loss) + tf.summary.scalar('generator_loss', Net.content_loss) + tf.summary.scalar('PSNR', psnr) + tf.summary.scalar('learning_rate', Net.learning_rate) + + + # Define the saver and weight initiallizer + saver = tf.train.Saver(max_to_keep=10) + + var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + if FLAGS.task == 'SRGAN': + #var_list2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') + \ + # tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') + var_list2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') + elif FLAGS.task == 'SRResnet': + var_list2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') + + weight_initiallizer = tf.train.Saver(var_list2) + + # When using MSE loss, no need to restore the vgg net + if not FLAGS.perceptual_mode == 'MSE': + vgg_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='vgg_19') + vgg_restore = tf.train.Saver(vgg_var_list) + # print(vgg_var_list) + # init_op = tf.global_variables_initializer() + # merged_summary = tf.summary.merge_all() + # summary_op = 
+    # Start the session
+    config = tf.ConfigProto()
+    config.gpu_options.allow_growth = True
+    # with tf.Session(config=config) as sess:
+    # Use a Supervisor to coordinate the input queues and the summary writer
+    sv = tf.train.Supervisor(logdir=FLAGS.summary_dir, save_summaries_secs=0, saver=None)
+    with sv.managed_session(config=config) as sess:
+        if (FLAGS.checkpoint is not None) and (FLAGS.pre_trained_model is False):
+            print('Loading model from the checkpoint...')
+            checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
+            saver.restore(sess, checkpoint)
+
+        elif (FLAGS.checkpoint is not None) and (FLAGS.pre_trained_model is True):
+            print('Loading weights from the pre-trained model')
+            weight_initializer.restore(sess, FLAGS.checkpoint)
+
+        if FLAGS.perceptual_mode != 'MSE':
+            vgg_restore.restore(sess, FLAGS.vgg_ckpt)
+            print('VGG19 restored successfully!')
+
+        # Coordinate the multiple threads
+        # coord = tf.train.Coordinator()
+        # threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+        # Perform the training
+        if FLAGS.max_epoch is None:
+            if FLAGS.max_iter is None:
+                raise ValueError('One of max_epoch or max_iter must be provided')
+            else:
+                max_iter = FLAGS.max_iter
+        else:
+            max_iter = FLAGS.max_epoch * data.steps_per_epoch
+
+        print('Optimization starts!')
+        start = time.time()
+        for step in range(max_iter):
+            fetches = {
+                "train": Net.train,
+                "global_step": sv.global_step,
+            }
+
+            if ((step + 1) % FLAGS.display_freq) == 0:
+                if FLAGS.task == 'SRGAN':
+                    fetches["discrim_loss"] = Net.discrim_loss
+                    fetches["adversarial_loss"] = Net.adversarial_loss
+                    fetches["content_loss"] = Net.content_loss
+                    fetches["PSNR"] = psnr
+                    fetches["learning_rate"] = Net.learning_rate
+                    fetches["global_step"] = Net.global_step
+                elif FLAGS.task == 'SRResnet':
+                    fetches["content_loss"] = Net.content_loss
+                    fetches["PSNR"] = psnr
+                    fetches["learning_rate"] = Net.learning_rate
+                    fetches["global_step"] = Net.global_step
+
+            if ((step + 1) % FLAGS.summary_freq) == 0:
+                # fetches["summary"] = merged_summary
+                fetches["summary"] = sv.summary_op
+
+            results = sess.run(fetches)
+
+            if ((step + 1) % FLAGS.summary_freq) == 0:
+                print('Recording summary!')
+                # summary_op.add_summary(results['summary'], results['global_step'])
+                sv.summary_writer.add_summary(results['summary'], results['global_step'])
+
+            if ((step + 1) % FLAGS.display_freq) == 0:
+                train_epoch = math.ceil(results["global_step"] / data.steps_per_epoch)
+                train_step = (results["global_step"] - 1) % data.steps_per_epoch + 1
+                rate = (step + 1) * FLAGS.batch_size / (time.time() - start)
+                remaining = (max_iter - step) * FLAGS.batch_size / rate
+                print("progress epoch %d step %d image/sec %0.1f remaining %dm" % (train_epoch, train_step, rate, remaining / 60))
+                if FLAGS.task == 'SRGAN':
+                    print("global_step", results["global_step"])
+                    print("PSNR", results["PSNR"])
+                    print("discrim_loss", results["discrim_loss"])
+                    print("adversarial_loss", results["adversarial_loss"])
+                    print("content_loss", results["content_loss"])
+                    print("learning_rate", results['learning_rate'])
+                elif FLAGS.task == 'SRResnet':
+                    print("global_step", results["global_step"])
+                    print("PSNR", results["PSNR"])
+                    print("content_loss", results["content_loss"])
+                    print("learning_rate", results['learning_rate'])
+                # print('check', results['check'])
+
+
+            if ((step + 1) % FLAGS.save_freq) == 0:
+                print('Saving the checkpoint')
+                saver.save(sess, os.path.join(FLAGS.output_dir, 'model'), global_step=sv.global_step)
+
+        # coord.request_stop()
+        # coord.join(threads)
+
+        print('Optimization done!')
+
+
+
+
+
+
+
+
diff --git a/test_SRGAN.sh b/test_SRGAN.sh
new file mode 100644
index 0000000..7c39ae7
--- /dev/null
+++ b/test_SRGAN.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+CUDA_VISIBLE_DEVICES=0 python main.py \
+    --output_dir ./val_result/SRGAN_VGG54/ \
+    --summary_dir ./val_result/SRGAN_VGG54/log/ \
+    --mode test \
+    --is_training False \
+    --task SRGAN \
+    --batch_size 16 \
+    --input_dir_LR ./data/Set14_LR/ \
+    --input_dir_HR ./data/Set14_HR/ \
+    --num_resblock 16 \
+    --perceptual_mode VGG54 \
+    --pre_trained_model True \
+    --checkpoint ./experiment_SRGAN_VGG54/model-200000
+
diff --git a/tool/__init__.py b/tool/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tool/convertPNG.py b/tool/convertPNG.py
new file mode 100644
index 0000000..e69de29
diff --git a/tool/resizeImage.py b/tool/resizeImage.py
new file mode 100644
index 0000000..07584b4
--- /dev/null
+++ b/tool/resizeImage.py
@@ -0,0 +1,39 @@
+from PIL import Image
+import os
+from multiprocessing import Pool
+
+# Define the input and output directories
+input_dir = '../data/RAISE_HR/'
+output_dir = '../data/RAISE_LR/'
+scale = 4.
+
+if not os.path.exists(output_dir):
+    os.mkdir(output_dir)
+
+image_list = os.listdir(input_dir)
+image_list = [os.path.join(input_dir, _) for _ in image_list]
+
+# Worker function used by the multiprocessing pool
+def downscale(name):
+    print(name)
+    with Image.open(name) as im:
+        w, h = im.size
+        w_new = int(w / scale)
+        h_new = int(h / scale)
+        im_new = im.resize((w_new, h_new), Image.ANTIALIAS)
+
+        save_name = os.path.join(output_dir, name.split('/')[-1])
+        im_new.save(save_name)
+
+p = Pool(5)
+p.map(downscale, image_list)
+# for name in image_list:
+#     print(name)
+#     with Image.open(name) as im:
+#         w, h = im.size
+#         w_new = int(w / scale)
+#         h_new = int(h / scale)
+#         im.resize((w_new, h_new), Image.ANTIALIAS)
+#
+#         save_name = os.path.join(output_dir, name.split('/')[-1].split('-0')[0]+'.png')
+#         im.save(save_name)
\ No newline at end of file
diff --git a/train_SRGAN.sh b/train_SRGAN.sh
new file mode 100644
index 0000000..6544023
--- /dev/null
+++ b/train_SRGAN.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+CUDA_VISIBLE_DEVICES=0 python main.py \
+    --output_dir ./experiment_SRGAN_VGG54/ \
+    --summary_dir ./experiment_SRGAN_VGG54/log/ \
+    --mode train \
+    --is_training True \
+    --task SRGAN \
+    --batch_size 16 \
+    --flip True \
+    --random_crop True \
+    --crop_size 24 \
+    --input_dir_LR ./data/RAISE_LR/ \
+    --input_dir_HR ./data/RAISE_HR/ \
+    --num_resblock 16 \
+    --perceptual_mode VGG54 \
+    --name_queue_capacity 4096 \
+    --image_queue_capacity 4096 \
+    --ratio 0.001 \
+    --learning_rate 0.0001 \
+    --decay_step 100000 \
+    --decay_rate 0.1 \
+    --stair True \
+    --beta 0.9 \
+    --max_iter 200000 \
+    --queue_thread 10 \
+    --vgg_scaling 0.0061 \
+    --pre_trained_model True \
+    --checkpoint ./experiment_SRGAN_MSE/model-500000
+
diff --git a/train_SRResnet.sh b/train_SRResnet.sh
new file mode 100644
index 0000000..30e6d82
--- /dev/null
+++ b/train_SRResnet.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+CUDA_VISIBLE_DEVICES=0 python main.py \
+    --output_dir ./experiment_SRResnet_dense/ \
+    --summary_dir ./experiment_SRResnet_dense/log/ \
+    --mode train \
+    --is_training True \
+    --task SRResnet \
+    --generator_type denseNet \
+    --batch_size 16 \
+    --flip True \
+    --random_crop True \
+    --crop_size 24 \
+    --input_dir_LR ./data/RAISE_LR/ \
+    --input_dir_HR ./data/RAISE_HR/ \
+    --num_resblock 16 \
+    --name_queue_capacity 4096 \
+    --image_queue_capacity 4096 \
+    --perceptual_mode MSE \
+    --queue_thread 12 \
+    --ratio 0.001 \
+    --learning_rate 0.0001 \
+    --decay_step 400000 \
+    --decay_rate 0.1 \
+    --stair False \
+    --beta 0.9 \
+    --max_iter 1000000 \
+    --save_freq 20000
+