diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ddbcc37
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.pyc
+*.ckpt
\ No newline at end of file
diff --git a/lib/__init__.py b/lib/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lib/model.py b/lib/model.py
new file mode 100644
index 0000000..c10435f
--- /dev/null
+++ b/lib/model.py
@@ -0,0 +1,878 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from lib.ops import *
+from lib.model_dense import generatorDense
+import collections
+import os
+import math
+from tensorflow.contrib.keras.api.keras.applications.vgg19 import VGG19
+from tensorflow.contrib.keras.api.keras.models import Model
+import scipy.misc as sic
+import numpy as np
+
+
+# Define the data loader
+def data_loader(FLAGS):
+ # Define the returned data batches
+ Data = collections.namedtuple('Data', 'paths, inputs, targets, image_count, steps_per_epoch')
+
+ #Check the input directory
+ if FLAGS.input_dir == 'None':
+ raise ValueError('Input directory is not provided')
+
+ if not os.path.exists(FLAGS.input_dir):
+ raise ValueError('Input directory not found')
+
+ image_list = os.listdir(FLAGS.input_dir)
+ image_list = [_ for _ in image_list if _.endswith('.png')]
+ if len(image_list)==0:
+ raise Exception('No png files in the input directory')
+
+ image_list = sorted(image_list)
+ image_list = [os.path.join(FLAGS.input_dir, _) for _ in image_list]
+
+ with tf.variable_scope('load_image'):
+ # define the image list queue
+        # Shuffle only during training (FLAGS.mode uses the lower-case value 'train')
+        image_list_queue = tf.train.string_input_producer(image_list, shuffle=(FLAGS.mode == 'train'), capacity=FLAGS.name_queue_capacity)
+        print('[Queue] image list queue use shuffle: %s' % (FLAGS.mode == 'train'))
+
+        # Read and decode the images
+ reader = tf.WholeFileReader(name='image_reader')
+ paths, image = reader.read(image_list_queue)
+ input_image = tf.image.decode_png(image)
+ input_image = tf.image.convert_image_dtype(input_image, dtype=tf.float32)
+
+ assertion = tf.assert_equal(tf.shape(input_image)[2], 3, message="image does not have 3 channels")
+ with tf.control_dependencies([assertion]):
+ input_image = tf.identity(input_image)
+
+ # Break apart the paired images and normalize to [-1, 1]
+ width = tf.shape(input_image)[1]
+ a_image = preprocess(input_image[:, :(width // 2), :])
+ b_image = preprocess(input_image[:, (width // 2):, :])
+
+
+ if FLAGS.which_direction == "AtoB":
+ inputs, targets = [a_image, b_image]
+ elif FLAGS.which_direction == "BtoA":
+ inputs, targets = [b_image, a_image]
+ else:
+ raise Exception("invalid direction")
+
+ # The data augmentation part
+ with tf.variable_scope('data_preprocessing'):
+ with tf.variable_scope('random_crop'):
+ # Check whether perform crop
+ if (FLAGS.random_crop is True) and (FLAGS.crop_size < FLAGS.train_image_width and
+ FLAGS.crop_size < FLAGS.train_image_height) and FLAGS.mode == 'train':
+ print('[Config] Use random crop')
+ inputs.set_shape([FLAGS.train_image_height, FLAGS.train_image_width, 3])
+ targets.set_shape([FLAGS.train_image_height, FLAGS.train_image_width, 3])
+ offset_w = tf.cast(tf.floor(tf.random_uniform([], 0, FLAGS.train_image_width - FLAGS.crop_size)), dtype=tf.int32)
+ offset_h = tf.cast(tf.floor(tf.random_uniform([], 0, FLAGS.train_image_height - FLAGS.crop_size)), dtype=tf.int32)
+
+ inputs = tf.image.crop_to_bounding_box(inputs, offset_h, offset_w, FLAGS.crop_size, FLAGS.crop_size)
+ targets = tf.image.crop_to_bounding_box(targets, offset_h, offset_w, FLAGS.crop_size, FLAGS.crop_size)
+ # Do not perform crop
+ else:
+ inputs = tf.identity(inputs)
+ targets = tf.identity(targets)
+
+ with tf.variable_scope('random_flip'):
+ # Check for random flip:
+ if (FLAGS.flip is True) and (FLAGS.mode == 'train'):
+ print('[Config] Use random flip')
+ # Produce the decision of random flip
+ decision = tf.random_uniform([], 0, 1, dtype=tf.float32)
+
+ input_images = random_flip(inputs, decision)
+ target_images = random_flip(targets, decision)
+ else:
+ input_images = tf.identity(inputs)
+ target_images = tf.identity(targets)
+
+ if FLAGS.mode == 'train':
+ paths_batch, inputs_batch, targets_batch = tf.train.shuffle_batch([paths, input_images, target_images],
+ batch_size=FLAGS.batch_size, capacity=FLAGS.image_queue_capacity,
+ min_after_dequeue=512, num_threads=20)
+ else:
+ paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images],
+ batch_size=FLAGS.batch_size, num_threads=20, allow_smaller_final_batch=True)
+
+ steps_per_epoch = int(math.ceil(len(image_list) / FLAGS.batch_size))
+
+ return Data(
+ paths=paths_batch,
+ inputs=inputs_batch,
+ targets=targets_batch,
+ image_count=len(image_list),
+ steps_per_epoch=steps_per_epoch
+ )
+
+
+def data_loader2(FLAGS):
+ with tf.device('/cpu:0'):
+ # Define the returned data batches
+ Data = collections.namedtuple('Data', 'paths_LR, paths_HR, inputs, targets, image_count, steps_per_epoch')
+
+ #Check the input directory
+ if (FLAGS.input_dir_LR == 'None') or (FLAGS.input_dir_HR == 'None'):
+ raise ValueError('Input directory is not provided')
+
+ if (not os.path.exists(FLAGS.input_dir_LR)) or (not os.path.exists(FLAGS.input_dir_HR)):
+ raise ValueError('Input directory not found')
+
+ image_list_LR = os.listdir(FLAGS.input_dir_LR)
+ image_list_LR = [_ for _ in image_list_LR if _.endswith('.png')]
+ if len(image_list_LR)==0:
+ raise Exception('No png files in the input directory')
+
+ image_list_LR_temp = sorted(image_list_LR)
+ image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp]
+ image_list_HR = [os.path.join(FLAGS.input_dir_HR, _) for _ in image_list_LR_temp]
+
+ image_list_LR_tensor = tf.convert_to_tensor(image_list_LR, dtype=tf.string)
+ image_list_HR_tensor = tf.convert_to_tensor(image_list_HR, dtype=tf.string)
+
+ with tf.variable_scope('load_image'):
+ # define the image list queue
+ # image_list_LR_queue = tf.train.string_input_producer(image_list_LR, shuffle=False, capacity=FLAGS.name_queue_capacity)
+ # image_list_HR_queue = tf.train.string_input_producer(image_list_HR, shuffle=False, capacity=FLAGS.name_queue_capacity)
+ #print('[Queue] image list queue use shuffle: %s'%(FLAGS.mode == 'Train'))
+ output = tf.train.slice_input_producer([image_list_LR_tensor, image_list_HR_tensor],
+ shuffle=False, capacity=FLAGS.name_queue_capacity)
+
+            # Read and decode the images
+ reader = tf.WholeFileReader(name='image_reader')
+ image_LR = tf.read_file(output[0])
+ image_HR = tf.read_file(output[1])
+ input_image_LR = tf.image.decode_png(image_LR, channels=3)
+ input_image_HR = tf.image.decode_png(image_HR, channels=3)
+ input_image_LR = tf.image.convert_image_dtype(input_image_LR, dtype=tf.float32)
+ input_image_HR = tf.image.convert_image_dtype(input_image_HR, dtype=tf.float32)
+
+ assertion = tf.assert_equal(tf.shape(input_image_LR)[2], 3, message="image does not have 3 channels")
+ with tf.control_dependencies([assertion]):
+ input_image_LR = tf.identity(input_image_LR)
+ input_image_HR = tf.identity(input_image_HR)
+
+ # Normalize the low resolution image to [0, 1], high resolution to [-1, 1]
+ a_image = preprocessLR(input_image_LR)
+ b_image = preprocess(input_image_HR)
+
+ inputs, targets = [a_image, b_image]
+
+ # The data augmentation part
+ with tf.name_scope('data_preprocessing'):
+ with tf.name_scope('random_crop'):
+ # Check whether perform crop
+ if (FLAGS.random_crop is True) and FLAGS.mode == 'train':
+ print('[Config] Use random crop')
+ # Set the shape of the input image. the target will have 4X size
+ input_size = tf.shape(inputs)
+ target_size = tf.shape(targets)
+ offset_w = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[1], tf.float32) - FLAGS.crop_size)),
+ dtype=tf.int32)
+ offset_h = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[0], tf.float32) - FLAGS.crop_size)),
+ dtype=tf.int32)
+
+ if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
+ inputs = tf.image.crop_to_bounding_box(inputs, offset_h, offset_w, FLAGS.crop_size,
+ FLAGS.crop_size)
+ targets = tf.image.crop_to_bounding_box(targets, offset_h*4, offset_w*4, FLAGS.crop_size*4,
+ FLAGS.crop_size*4)
+ elif FLAGS.task == 'denoise':
+ inputs = tf.image.crop_to_bounding_box(inputs, offset_h, offset_w, FLAGS.crop_size,
+ FLAGS.crop_size)
+ targets = tf.image.crop_to_bounding_box(targets, offset_h, offset_w,
+ FLAGS.crop_size, FLAGS.crop_size)
+ # Do not perform crop
+ else:
+ inputs = tf.identity(inputs)
+ targets = tf.identity(targets)
+
+ with tf.variable_scope('random_flip'):
+ # Check for random flip:
+ if (FLAGS.flip is True) and (FLAGS.mode == 'train'):
+ print('[Config] Use random flip')
+ # Produce the decision of random flip
+ decision = tf.random_uniform([], 0, 1, dtype=tf.float32)
+
+ input_images = random_flip(inputs, decision)
+ target_images = random_flip(targets, decision)
+ else:
+ input_images = tf.identity(inputs)
+ target_images = tf.identity(targets)
+
+ if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
+ input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
+ target_images.set_shape([FLAGS.crop_size*4, FLAGS.crop_size*4, 3])
+ elif FLAGS.task == 'denoise':
+ input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
+ target_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
+
+ if FLAGS.mode == 'train':
+ paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.shuffle_batch([output[0], output[1], input_images, target_images],
+ batch_size=FLAGS.batch_size, capacity=FLAGS.image_queue_capacity+4*FLAGS.batch_size,
+ min_after_dequeue=FLAGS.image_queue_capacity, num_threads=FLAGS.queue_thread)
+ else:
+ paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.batch([output[0], output[1], input_images, target_images],
+ batch_size=FLAGS.batch_size, num_threads=FLAGS.queue_thread, allow_smaller_final_batch=True)
+
+ steps_per_epoch = int(math.ceil(len(image_list_LR) / FLAGS.batch_size))
+ if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
+ inputs_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
+ targets_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size*4, FLAGS.crop_size*4, 3])
+ elif FLAGS.task == 'denoise':
+ inputs_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
+ targets_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
+ return Data(
+ paths_LR=paths_LR_batch,
+ paths_HR=paths_HR_batch,
+ inputs=inputs_batch,
+ targets=targets_batch,
+ image_count=len(image_list_LR),
+ steps_per_epoch=steps_per_epoch
+ )
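+
+# Illustrative flag settings for this loader (a sketch only; the paths below are
+# hypothetical and the real values come from the tf.app.flags defined in main.py):
+# with input_dir_LR='./data/train_LR', input_dir_HR='./data/train_HR',
+# mode='train', task='SRGAN' and crop_size=24, data_loader2 returns batches where
+# data.inputs is [batch_size, 24, 24, 3] in [0, 1] and
+# data.targets is [batch_size, 96, 96, 3] in [-1, 1].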
+
+
+# The test data loader. Allows input images of different sizes
+def test_data_loader(FLAGS):
+ # Get the image name list
+ if (FLAGS.input_dir_LR == 'None') or (FLAGS.input_dir_HR == 'None'):
+ raise ValueError('Input directory is not provided')
+
+ if (not os.path.exists(FLAGS.input_dir_LR)) or (not os.path.exists(FLAGS.input_dir_HR)):
+ raise ValueError('Input directory not found')
+
+ image_list_LR_temp = os.listdir(FLAGS.input_dir_LR)
+ image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'png']
+ image_list_HR = [os.path.join(FLAGS.input_dir_HR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'png']
+
+ # Read in and preprocess the images
+ def preprocess_test(name, mode):
+ im = sic.imread(name).astype(np.float32)
+ # check grayscale image
+ if im.shape[-1] != 3:
+ h, w = im.shape
+ temp = np.empty((h, w, 3), dtype=np.uint8)
+ temp[:, :, :] = im[:, :, np.newaxis]
+ im = temp.copy()
+ if mode == 'LR':
+ im = im / np.max(im)
+ elif mode == 'HR':
+ im = im / np.max(im)
+ im = im * 2 - 1
+
+ return im
+
+ image_LR = [preprocess_test(_, 'LR') for _ in image_list_LR]
+ image_HR = [preprocess_test(_, 'HR') for _ in image_list_HR]
+
+ # Push path and image into a list
+ Data = collections.namedtuple('Data', 'paths_LR, paths_HR, inputs, targets')
+
+ return Data(
+ paths_LR = image_list_LR,
+ paths_HR = image_list_HR,
+ inputs = image_LR,
+ targets = image_HR
+ )
+
+
+# Definition of the generator
+def generator(gen_inputs, gen_output_channels, reuse=False, FLAGS=None):
+ # Check the flag
+ if FLAGS is None:
+ raise ValueError('No FLAGS is provided for generator')
+
+ # The Bx residual blocks
+ def residual_block(inputs, output_channel, stride, scope):
+ with tf.variable_scope(scope):
+ net = conv2(inputs, 3, output_channel, stride, use_bias=False, scope='conv_1')
+ net = batchnorm(net, FLAGS.is_training)
+ net = prelu_tf(net)
+ net = conv2(net, 3, output_channel, stride, use_bias=False, scope='conv_2')
+ net = batchnorm(net, FLAGS.is_training)
+ net = net + inputs
+
+ return net
+
+
+ with tf.variable_scope('generator_unit', reuse=reuse):
+ # The input layer
+ with tf.variable_scope('input_stage'):
+ net = conv2(gen_inputs, 9, 64, 1, scope='conv')
+ net = prelu_tf(net)
+
+ stage1_output = net
+
+ # The residual block parts
+ for i in range(1, FLAGS.num_resblock+1 , 1):
+ name_scope = 'resblock_%d'%(i)
+ net = residual_block(net, 64, 1, name_scope)
+
+ with tf.variable_scope('resblock_output'):
+ net = conv2(net, 3, 64, 1, use_bias=False, scope='conv')
+ net = batchnorm(net, FLAGS.is_training)
+
+ net = net + stage1_output
+
+ with tf.variable_scope('subpixelconv_stage1'):
+ net = conv2(net, 3, 256, 1, scope='conv')
+ net = pixelShuffler(net, scale=2)
+ net = prelu_tf(net)
+
+ with tf.variable_scope('subpixelconv_stage2'):
+ net = conv2(net, 3, 256, 1, scope='conv')
+ net = pixelShuffler(net, scale=2)
+ net = prelu_tf(net)
+
+ with tf.variable_scope('output_stage'):
+ net = conv2(net, 9, gen_output_channels, 1, scope='conv')
+
+ return net
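+
+# Shape sketch (illustrative; assumes FLAGS provides num_resblock and is_training):
+# the two subpixel stages each upsample by 2x, so the generator is a fixed 4x
+# super-resolution network. For example (not executed here):
+#
+#   lr_patch = tf.placeholder(tf.float32, [None, 24, 24, 3])
+#   sr_patch = generator(lr_patch, gen_output_channels=3, reuse=False, FLAGS=FLAGS)
+#   # sr_patch has spatial size 96x96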
+
+
+# Define the generator for the denoise task
+def generator_denoise(gen_inputs, gen_output_channels, reuse=False, FLAGS=None):
+ # Check the flag
+ if FLAGS is None:
+ raise ValueError('No FLAGS is provided for generator')
+
+ # The Bx residual blocks
+ def residual_block(inputs, output_channel, stride, scope):
+ with tf.variable_scope(scope):
+ net = conv2(inputs, 3, output_channel, stride, use_bias=False, scope='conv_1')
+ net = batchnorm(net, FLAGS.is_training)
+ net = prelu_tf(net)
+ net = conv2(net, 3, output_channel, stride, use_bias=False, scope='conv_2')
+ net = batchnorm(net, FLAGS.is_training)
+ net = net + inputs
+
+ return net
+
+ # [optional] Put network on different GPU
+ with tf.device('/gpu:0'):
+ with tf.variable_scope('generator_unit', reuse=reuse):
+ # The input layer
+ with tf.variable_scope('input_stage'):
+ net = conv2(gen_inputs, 9, 64, 1, scope='conv')
+ net = prelu_tf(net)
+
+ stage1_output = net
+
+ # The residual block parts
+ for i in range(1, FLAGS.num_resblock+1 , 1):
+ name_scope = 'resblock_%d'%(i)
+ net = residual_block(net, 64, 1, name_scope)
+
+ with tf.variable_scope('resblock_output'):
+ net = conv2(net, 3, 64, 1, use_bias=False, scope='conv')
+ net = batchnorm(net, FLAGS.is_training)
+
+ net = net + stage1_output
+
+ with tf.device('/gpu:0'):
+ with tf.variable_scope('refine_stage1'):
+ net = conv2(net, 3, 256, 1, scope='conv')
+ net = batchnorm(net, FLAGS.is_training)
+ net = prelu_tf(net)
+
+ with tf.variable_scope('refine_stage2'):
+ net = conv2(net, 3, 256, 1, scope='conv')
+ net = batchnorm(net, FLAGS.is_training)
+ net = prelu_tf(net)
+
+ with tf.variable_scope('output_stage'):
+ net = conv2(net, 9, gen_output_channels, 1, scope='conv')
+ print(net.get_shape(), 'output_stage')
+
+ return net
+
+
+# Definition of the discriminator
+def discriminator(dis_inputs, FLAGS=None):
+ if FLAGS is None:
+        raise ValueError('No FLAGS is provided for discriminator')
+
+ # Define the discriminator block
+ def discriminator_block(inputs, output_channel, kernel_size, stride, scope):
+ with tf.variable_scope(scope):
+ net = conv2(inputs, kernel_size, output_channel, stride, use_bias=False, scope='conv1')
+ net = batchnorm(net, FLAGS.is_training)
+ net = lrelu(net, 0.2)
+
+ return net
+
+ with tf.device('/gpu:0'):
+ with tf.variable_scope('discriminator_unit'):
+ # The input layer
+ with tf.variable_scope('input_stage'):
+ net = conv2(dis_inputs, 3, 64, 1, scope='conv')
+ net = lrelu(net, 0.2)
+
+ # The discriminator block part
+ # block 1
+ net = discriminator_block(net, 64, 3, 2, 'disblock_1')
+
+ # block 2
+ net = discriminator_block(net, 128, 3, 1, 'disblock_2')
+
+ # block 3
+ net = discriminator_block(net, 128, 3, 2, 'disblock_3')
+
+ # block 4
+ net = discriminator_block(net, 256, 3, 1, 'disblock_4')
+
+ # block 5
+ net = discriminator_block(net, 256, 3, 2, 'disblock_5')
+
+ # block 6
+ net = discriminator_block(net, 512, 3, 1, 'disblock_6')
+
+ # block_7
+ net = discriminator_block(net, 512, 3, 2, 'disblock_7')
+
+ # The dense layer 1
+ with tf.variable_scope('dense_layer_1'):
+ net = slim.flatten(net)
+ net = denselayer(net, 1024)
+ net = lrelu(net, 0.2)
+
+ # The dense layer 2
+ with tf.variable_scope('dense_layer_2'):
+ net = denselayer(net, 1)
+ net = tf.nn.sigmoid(net)
+
+ return net
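+
+# Note (inferred, not enforced by the code): because slim.flatten feeds a dense
+# layer, the discriminator only supports a fixed input resolution. With the
+# default crop_size of 24 the HR patches are 96x96, which the four stride-2
+# blocks reduce to 6x6x512 before dense_layer_1.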
+
+
+# Define the feature extractor
+def VGG19_keras(type):
+    # Define the feature layer to extract according to the type of perceptual loss
+ if type == 'VGG54':
+ target_layer = 'block5_conv4'
+ elif type == 'VGG22':
+ target_layer = 'block2_conv2'
+ else:
+ raise NotImplementedError('Unknown perceptual type')
+ # Define the base model
+ base_model = VGG19(include_top=False, weights='imagenet')
+ extractor = Model(inputs=base_model.inputs, outputs=base_model.get_layer(target_layer).output)
+
+ return extractor
+
+
+def VGG19_slim(input, type, reuse, scope):
+    # Define the feature layer to extract according to the type of perceptual loss
+ if type == 'VGG54':
+ target_layer = scope + 'vgg_19/conv5/conv5_4'
+ elif type == 'VGG22':
+ target_layer = scope + 'vgg_19/conv2/conv2_2'
+ else:
+ raise NotImplementedError('Unknown perceptual type')
+ _, output = vgg_19(input, is_training=False, reuse=reuse)
+ output = output[target_layer]
+
+ return output
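+
+# Minimal usage sketch (illustrative only; the vgg_19 weights must be restored
+# from FLAGS.vgg_ckpt before the extracted features are meaningful). This mirrors
+# the perceptual-loss code used in SRGAN/SRResnet below:
+#
+#   with tf.name_scope('vgg19_1') as scope:
+#       feat_gen = VGG19_slim(gen_output, 'VGG54', reuse=False, scope=scope)
+#   with tf.name_scope('vgg19_2') as scope:
+#       feat_target = VGG19_slim(targets, 'VGG54', reuse=True, scope=scope)
+#   content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(feat_gen - feat_target), axis=[3]))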
+
+
+# Define the whole network architecture
+def SRGAN(inputs, targets, FLAGS):
+ # Define the container of the parameter
+ Network = collections.namedtuple('Network', 'discrim_real_output, discrim_fake_output, discrim_loss, \
+ discrim_grads_and_vars, adversarial_loss, content_loss, gen_grads_and_vars, gen_output, train, global_step, \
+ learning_rate')
+
+ # Build the generator part
+ with tf.variable_scope('generator'):
+ output_channel = targets.get_shape().as_list()[-1]
+ gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS)
+ gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size*4, FLAGS.crop_size*4, 3])
+
+ # Build the fake discriminator
+ with tf.name_scope('fake_discriminator'):
+ with tf.variable_scope('discriminator', reuse=False):
+ discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS)
+
+ # Build the real discriminator
+ with tf.name_scope('real_discriminator'):
+ with tf.variable_scope('discriminator', reuse=True):
+ discrim_real_output = discriminator(targets, FLAGS=FLAGS)
+
+ # Use the VGG54 feature
+ if FLAGS.perceptual_mode == 'VGG54':
+ with tf.name_scope('vgg19_1') as scope:
+ extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
+ with tf.name_scope('vgg19_2') as scope:
+ extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
+
+ elif FLAGS.perceptual_mode == 'VGG22':
+ extractor = VGG19_keras(FLAGS.perceptual_mode)
+ extracted_feature_gen = extractor.call(gen_output)
+ extracted_feature_target = extractor.call(targets)
+
+ elif FLAGS.perceptual_mode == 'MSE':
+ extracted_feature_gen = gen_output
+ extracted_feature_target = targets
+
+ else:
+ raise NotImplementedError('Unknown perceptual type')
+
+ # Calculating the generator loss
+ with tf.variable_scope('generator_loss'):
+ # Content loss
+ with tf.variable_scope('content_loss'):
+ # Compute the euclidean distance between the two features
+ # check=tf.equal(extracted_feature_gen, extracted_feature_target)
+ diff = extracted_feature_gen - extracted_feature_target
+ if FLAGS.perceptual_mode == 'MSE':
+ content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+ else:
+ content_loss = FLAGS.vgg_scaling*tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+
+ with tf.variable_scope('adversarial_loss'):
+ adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS))
+
+ gen_loss = content_loss + (FLAGS.ratio)*adversarial_loss
+ print(adversarial_loss.get_shape())
+ print(content_loss.get_shape())
+
+ # Calculating the discriminator loss
+ with tf.variable_scope('discriminator_loss'):
+ discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS)
+ discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS)
+
+ discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss))
+
+ # Define the learning rate and global step
+ with tf.variable_scope('get_learning_rate_and_global_step'):
+ global_step = tf.contrib.framework.get_or_create_global_step()
+ learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, staircase=FLAGS.stair)
+ incr_global_step = tf.assign(global_step, global_step + 1)
+
+ with tf.variable_scope('dicriminator_train'):
+ discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
+ discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
+ discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars)
+ discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars)
+
+ with tf.variable_scope('generator_train'):
+        # The generator train step has to wait for the discriminator train step and the update ops
+ with tf.control_dependencies([discrim_train]+ tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+ gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+ gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
+ gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
+ gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)
+
+    # [ToDo] What if we do not use a moving average on the losses?
+ exp_averager = tf.train.ExponentialMovingAverage(decay=0.99)
+ update_loss = exp_averager.apply([discrim_loss, adversarial_loss, content_loss])
+
+ return Network(
+ discrim_real_output = discrim_real_output,
+ discrim_fake_output = discrim_fake_output,
+ discrim_loss = exp_averager.average(discrim_loss),
+ discrim_grads_and_vars = discrim_grads_and_vars,
+ adversarial_loss = exp_averager.average(adversarial_loss),
+ content_loss = exp_averager.average(content_loss),
+ gen_grads_and_vars = gen_grads_and_vars,
+ gen_output = gen_output,
+ train = tf.group(update_loss, incr_global_step, gen_train),
+ global_step = global_step,
+ learning_rate = learning_rate
+ )
+
+
+def SRResnet(inputs, targets, FLAGS):
+ # Define the container of the parameter
+ Network = collections.namedtuple('Network', 'content_loss, gen_grads_and_vars, gen_output, train, global_step, \
+ learning_rate')
+
+ # Build the generator part
+ with tf.variable_scope('generator'):
+ output_channel = targets.get_shape().as_list()[-1]
+ gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS)
+ gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3])
+
+ # Use the VGG54 feature
+ if FLAGS.perceptual_mode == 'VGG54':
+ with tf.name_scope('vgg19_1') as scope:
+ extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
+ with tf.name_scope('vgg19_2') as scope:
+ extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
+
+ elif FLAGS.perceptual_mode == 'VGG22':
+ with tf.name_scope('vgg19_1') as scope:
+ extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
+ with tf.name_scope('vgg19_2') as scope:
+ extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
+
+ elif FLAGS.perceptual_mode == 'MSE':
+ extracted_feature_gen = gen_output
+ extracted_feature_target = targets
+
+ else:
+ raise NotImplementedError('Unknown perceptual type')
+
+ # Calculating the generator loss
+ with tf.variable_scope('generator_loss'):
+ # Content loss
+ with tf.variable_scope('content_loss'):
+ # Compute the euclidean distance between the two features
+ # check=tf.equal(extracted_feature_gen, extracted_feature_target)
+ diff = extracted_feature_gen - extracted_feature_target
+ if FLAGS.perceptual_mode == 'MSE':
+ content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+ else:
+ content_loss = FLAGS.vgg_scaling * tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+
+ gen_loss = content_loss
+
+ # Define the learning rate and global step
+ with tf.variable_scope('get_learning_rate_and_global_step'):
+ global_step = tf.contrib.framework.get_or_create_global_step()
+ learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate,
+ staircase=FLAGS.stair)
+ incr_global_step = tf.assign(global_step, global_step + 1)
+
+ with tf.variable_scope('generator_train'):
+        # Wait for the batch norm update ops before running the generator train step
+ with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+ gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+ gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
+ gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
+ gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)
+
+    # [ToDo] What if we do not use a moving average on the losses?
+ exp_averager = tf.train.ExponentialMovingAverage(decay=0.99)
+ update_loss = exp_averager.apply([content_loss])
+
+ return Network(
+ content_loss=exp_averager.average(content_loss),
+ gen_grads_and_vars=gen_grads_and_vars,
+ gen_output=gen_output,
+ train=tf.group(update_loss, incr_global_step, gen_train),
+ global_step=global_step,
+ learning_rate=learning_rate
+ )
+
+
+# Use the denseNet version of the generator
+def SRResnet_dense(inputs, targets, FLAGS):
+ # Define the container of the parameter
+ Network = collections.namedtuple('Network', 'content_loss, gen_grads_and_vars, gen_output, train, global_step, \
+ learning_rate')
+
+ # Build the generator part
+ with tf.variable_scope('generator'):
+ output_channel = targets.get_shape().as_list()[-1]
+ gen_output = generatorDense(inputs, output_channel, reuse=False, FLAGS=FLAGS)
+ gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3])
+
+ # Use the VGG54 feature
+ if FLAGS.perceptual_mode == 'VGG54':
+ with tf.name_scope('vgg19_1') as scope:
+ extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
+ with tf.name_scope('vgg19_2') as scope:
+ extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
+
+ elif FLAGS.perceptual_mode == 'VGG22':
+ with tf.name_scope('vgg19_1') as scope:
+ extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
+ with tf.name_scope('vgg19_2') as scope:
+ extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
+
+ elif FLAGS.perceptual_mode == 'MSE':
+ extracted_feature_gen = gen_output
+ extracted_feature_target = targets
+
+ else:
+ raise NotImplementedError('Unknown perceptual type')
+
+ # Calculating the generator loss
+ with tf.variable_scope('generator_loss'):
+ # Content loss
+ with tf.variable_scope('content_loss'):
+ # Compute the euclidean distance between the two features
+ # check=tf.equal(extracted_feature_gen, extracted_feature_target)
+ diff = extracted_feature_gen - extracted_feature_target
+ if FLAGS.perceptual_mode == 'MSE':
+ content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+ else:
+ content_loss = FLAGS.vgg_scaling * tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+
+ gen_loss = content_loss
+
+ # Define the learning rate and global step
+ with tf.variable_scope('get_learning_rate_and_global_step'):
+ global_step = tf.contrib.framework.get_or_create_global_step()
+ learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate,
+ staircase=FLAGS.stair)
+ incr_global_step = tf.assign(global_step, global_step + 1)
+
+ with tf.variable_scope('generator_train'):
+        # Wait for the batch norm update ops before running the generator train step
+ with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+ gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+ gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
+ gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
+ gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)
+
+    # [ToDo] What if we do not use a moving average on the losses?
+ # exp_averager = tf.train.ExponentialMovingAverage(decay=0.99)
+ # update_loss = exp_averager.apply([content_loss])
+
+ return Network(
+ content_loss=content_loss,
+ gen_grads_and_vars=gen_grads_and_vars,
+ gen_output=gen_output,
+ train=tf.group(content_loss, incr_global_step, gen_train),
+ global_step=global_step,
+ learning_rate=learning_rate
+ )
+
+
+# Define the whole network architecture
+def network_denoise(inputs, targets, FLAGS):
+ # Define the container of the parameter
+ Network = collections.namedtuple('Network', 'discrim_real_output, discrim_fake_output, discrim_loss, \
+ discrim_grads_and_vars, adversarial_loss, content_loss, gen_grads_and_vars, gen_output, train, global_step, \
+ learning_rate')
+
+ # Build the generator part
+ with tf.variable_scope('generator'):
+ output_channel = targets.get_shape().as_list()[-1]
+ gen_output = generator_denoise(inputs, output_channel, reuse=False, FLAGS=FLAGS)
+ gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
+
+ # Build the fake discriminator
+ with tf.name_scope('fake_discriminator'):
+ with tf.variable_scope('discriminator', reuse=False):
+ discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS)
+
+ # Build the real discriminator
+ with tf.name_scope('real_discriminator'):
+ with tf.variable_scope('discriminator', reuse=True):
+ discrim_real_output = discriminator(targets, FLAGS=FLAGS)
+
+ # Use the VGG54 feature
+ if FLAGS.perceptual_mode == 'VGG54':
+ with tf.name_scope('vgg19_1') as scope:
+ extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
+ with tf.name_scope('vgg19_2') as scope:
+ extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
+
+ elif FLAGS.perceptual_mode == 'VGG22':
+ with tf.name_scope('vgg19_1') as scope:
+ extracted_feature_gen = VGG19_slim(gen_output, FLAGS.perceptual_mode, reuse=False, scope=scope)
+ with tf.name_scope('vgg19_2') as scope:
+ extracted_feature_target = VGG19_slim(targets, FLAGS.perceptual_mode, reuse=True, scope=scope)
+
+ elif FLAGS.perceptual_mode == 'MSE':
+ extracted_feature_gen = gen_output
+ extracted_feature_target = targets
+
+ else:
+ raise NotImplementedError('Unknown perceptual type')
+
+ # Calculating the generator loss
+ with tf.variable_scope('generator_loss'):
+ # Content loss
+ with tf.variable_scope('content_loss'):
+ # Compute the euclidean distance between the two features
+ if (FLAGS.perceptual_mode == 'VGG54') or (FLAGS.perceptual_mode == 'VGG22'):
+ diff = extracted_feature_gen - extracted_feature_target
+ content_loss = FLAGS.vgg_scaling * tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+ elif FLAGS.perceptual_mode == 'MSE':
+ diff = targets - gen_output
+ content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
+ else:
+ raise NotImplementedError('Unknown perceptual type')
+
+ with tf.variable_scope('adversarial_loss'):
+ adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS))
+
+ gen_loss = content_loss + (FLAGS.ratio)*adversarial_loss
+
+ # Calculating the discriminator loss
+ with tf.variable_scope('discriminator_loss'):
+ discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS)
+ discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS)
+
+ discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss))
+
+ # Define the learning rate and global step
+ with tf.variable_scope('get_learning_rate_and_global_step'):
+ global_step = tf.contrib.framework.get_or_create_global_step()
+ learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate, staircase=FLAGS.stair)
+ incr_global_step = tf.assign(global_step, global_step + 1)
+
+ with tf.variable_scope('dicriminator_train'):
+ discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
+ discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
+ discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars)
+ discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars)
+
+ with tf.variable_scope('generator_train'):
+        # The generator train step has to wait for the discriminator train step and the update ops
+ with tf.control_dependencies([discrim_train]+ tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+ gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+ gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
+ gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
+ gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)
+
+    # [ToDo] What if we do not use a moving average on the losses?
+ exp_averager = tf.train.ExponentialMovingAverage(decay=0.99)
+ update_loss = exp_averager.apply([discrim_loss, adversarial_loss, content_loss])
+
+ return Network(
+ discrim_real_output = discrim_real_output,
+ discrim_fake_output = discrim_fake_output,
+ discrim_loss = exp_averager.average(discrim_loss),
+ discrim_grads_and_vars = discrim_grads_and_vars,
+ adversarial_loss = exp_averager.average(adversarial_loss),
+ content_loss = exp_averager.average(content_loss),
+ gen_grads_and_vars = gen_grads_and_vars,
+ gen_output = gen_output,
+ train = tf.group(update_loss, incr_global_step, gen_train),
+ global_step = global_step,
+ learning_rate = learning_rate
+ )
+
+
+def save_images(fetches, FLAGS, step=None):
+ image_dir = os.path.join(FLAGS.output_dir, "images")
+ if not os.path.exists(image_dir):
+ os.makedirs(image_dir)
+
+ filesets = []
+ in_path = fetches['path_LR']
+ name, _ = os.path.splitext(os.path.basename(str(in_path)))
+ fileset = {"name": name, "step": step}
+ for kind in ["inputs", "outputs", "targets"]:
+ filename = name + "-" + kind + ".png"
+ if step is not None:
+ filename = "%08d-%s" % (step, filename)
+ fileset[kind] = filename
+ out_path = os.path.join(image_dir, filename)
+ contents = fetches[kind][0]
+ with open(out_path, "wb") as f:
+ f.write(contents)
+ filesets.append(fileset)
+ return filesets
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/model_dense.py b/lib/model_dense.py
new file mode 100644
index 0000000..38c51c0
--- /dev/null
+++ b/lib/model_dense.py
@@ -0,0 +1,95 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from lib.ops import *
+import collections
+import os
+import math
+import scipy.misc as sic
+import numpy as np
+
+
+# The dense layer
+def denseConvlayer(layer_inputs, bottleneck_scale, growth_rate, is_training):
+ # Build the bottleneck operation
+ net = layer_inputs
+ net_temp = tf.identity(net)
+ net = batchnorm(net, is_training)
+ net = prelu_tf(net, name='Prelu_1')
+ net = conv2(net, kernel=1, output_channel=bottleneck_scale*growth_rate, stride=1, use_bias=False, scope='conv1x1')
+ net = batchnorm(net, is_training)
+ net = prelu_tf(net, name='Prelu_2')
+ net = conv2(net, kernel=3, output_channel=growth_rate, stride=1, use_bias=False, scope='conv3x3')
+
+    # Concatenate the newly produced features onto the layer input (dense connectivity)
+ net = tf.concat([net_temp, net], axis=3)
+
+ return net
+
+
+# The transition layer
+def transitionLayer(layer_inputs, output_channel, is_training):
+ net = layer_inputs
+ net = batchnorm(net, is_training)
+ net = prelu_tf(net)
+ net = conv2(net, 1, output_channel, stride=1, use_bias=False, scope='conv1x1')
+
+ return net
+
+
+# The dense block
+def denseBlock(block_inputs, num_layers, bottleneck_scale, growth_rate, FLAGS):
+ # Build each layer consecutively
+ net = block_inputs
+ for i in range(num_layers):
+ with tf.variable_scope('dense_conv_layer%d'%(i+1)):
+ net = denseConvlayer(net, bottleneck_scale, growth_rate, FLAGS.is_training)
+
+ return net
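+
+# Channel bookkeeping (derived from the code above): every denseConvlayer
+# concatenates growth_rate new channels onto its input, so a block with C input
+# channels and num_layers layers returns C + num_layers * growth_rate channels.
+# With the configuration used in generatorDense below (64 input channels,
+# 16 layers, growth_rate 12) the block output has 64 + 16 * 12 = 256 channels
+# before the transition layer.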
+
+
+# Here we define the dense block version generator
+def generatorDense(gen_inputs, gen_output_channels, reuse=False, FLAGS=None):
+ # Check the flag
+ if FLAGS is None:
+ raise ValueError('No FLAGS is provided for generator')
+
+    # The main network
+ with tf.variable_scope('generator_unit', reuse=reuse):
+ # The input stage
+ with tf.variable_scope('input_stage'):
+ net = conv2(gen_inputs, 9, 64, 1, scope='conv')
+ net = prelu_tf(net)
+
+ # The dense block part
+ # Define the denseblock configuration
+ layer_per_block = 16
+ bottleneck_scale = 4
+ growth_rate = 12
+ transition_output_channel = 128
+ with tf.variable_scope('denseBlock_1'):
+ net = denseBlock(net, layer_per_block, bottleneck_scale, growth_rate, FLAGS)
+
+ with tf.variable_scope('transition_layer_1'):
+ net = transitionLayer(net, transition_output_channel, FLAGS.is_training)
+
+ with tf.variable_scope('subpixelconv_stage1'):
+ net = conv2(net, 3, 256, 1, scope='conv')
+ net = pixelShuffler(net, scale=2)
+ net = prelu_tf(net)
+
+ with tf.variable_scope('subpixelconv_stage2'):
+ net = conv2(net, 3, 256, 1, scope='conv')
+ net = pixelShuffler(net, scale=2)
+ net = prelu_tf(net)
+
+ with tf.variable_scope('output_stage'):
+ net = conv2(net, 9, gen_output_channels, 1, scope='conv')
+
+ return net
+
+
+
+
diff --git a/lib/ops.py b/lib/ops.py
new file mode 100644
index 0000000..a0ee165
--- /dev/null
+++ b/lib/ops.py
@@ -0,0 +1,231 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+import tensorflow.contrib.keras as keras
+import pdb
+
+
+def preprocess(image):
+ with tf.name_scope("preprocess"):
+ # [0, 1] => [-1, 1]
+ return image * 2 - 1
+
+
+def deprocess(image):
+ with tf.name_scope("deprocess"):
+ # [-1, 1] => [0, 1]
+ return (image + 1) / 2
+
+
+def preprocessLR(image):
+ with tf.name_scope("preprocessLR"):
+ return tf.identity(image)
+
+
+def deprocessLR(image):
+ with tf.name_scope("deprocessLR"):
+ return tf.identity(image)
+
+
+# Define the convolution building block
+def conv2(batch_input, kernel=3, output_channel=64, stride=1, use_bias=True, scope='conv'):
+ # kernel: An integer specifying the width and height of the 2D convolution window
+ with tf.variable_scope(scope):
+ if use_bias:
+ return slim.conv2d(batch_input, output_channel, [kernel, kernel], stride, 'SAME', data_format='NHWC',
+ activation_fn=None, weights_initializer=tf.contrib.layers.xavier_initializer())
+ else:
+ return slim.conv2d(batch_input, output_channel, [kernel, kernel], stride, 'SAME', data_format='NHWC',
+ activation_fn=None, weights_initializer=tf.contrib.layers.xavier_initializer(),
+ biases_initializer=None)
+
+
+def conv2_NCHW(batch_input, kernel=3, output_channel=64, stride=1, use_bias=True, scope='conv_NCHW'):
+    # Use NCHW to speed up the inference
+    # kernel: an integer specifying the width and height of the 2D convolution window
+    with tf.variable_scope(scope):
+        if use_bias:
+            return slim.conv2d(batch_input, output_channel, [kernel, kernel], stride, 'SAME', data_format='NCHW',
+                               activation_fn=None, weights_initializer=tf.contrib.layers.xavier_initializer())
+        else:
+            return slim.conv2d(batch_input, output_channel, [kernel, kernel], stride, 'SAME', data_format='NCHW',
+                               activation_fn=None, weights_initializer=tf.contrib.layers.xavier_initializer(),
+                               biases_initializer=None)
+
+# [ToDo]: Write code to test this function and check the variable reuse of this model
+# Define our Prelu
+def prelu(inputs, input_shape):
+ prelu_unit = keras.layers.PReLU()
+ prelu_unit.build(input_shape)
+ return prelu_unit.call(inputs)
+
+
+# Define our tensorflow version PRelu
+def prelu_tf(inputs, name='Prelu'):
+ with tf.variable_scope(name):
+ alphas = tf.get_variable('alpha', inputs.get_shape()[-1], initializer=tf.zeros_initializer(), dtype=tf.float32)
+ pos = tf.nn.relu(inputs)
+ neg = alphas * (inputs - abs(inputs)) * 0.5
+
+ return pos + neg
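+
+# Equivalent formulation (for reference): pos + neg computes
+# max(0, x) + alpha * min(0, x), i.e. the standard per-channel PReLU, with one
+# trainable alpha per input channel initialised to zero (so it starts as a ReLU).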
+
+# Define our Lrelu
+def lrelu(inputs, alpha):
+ return keras.layers.LeakyReLU(alpha=alpha).call(inputs)
+
+
+def batchnorm(inputs, is_training):
+ return slim.batch_norm(inputs, decay=0.9, epsilon=0.001, updates_collections=tf.GraphKeys.UPDATE_OPS,
+ scale=False, fused=True, is_training=is_training)
+
+
+# Our dense layer
+def denselayer(inputs, output_size):
+ output = tf.layers.dense(inputs, output_size, activation=None, kernel_initializer=tf.contrib.layers.xavier_initializer())
+ return output
+
+
+# The implementation of PixelShuffler
+def pixelShuffler(inputs, scale=2):
+ size = tf.shape(inputs)
+ batch_size = size[0]
+ h = size[1]
+ w = size[2]
+ c = inputs.get_shape().as_list()[-1]
+
+ # Get the target channel size
+ channel_target = c // (scale * scale)
+ channel_factor = c // channel_target
+
+ shape_1 = [batch_size, h, w, channel_factor // scale, channel_factor // scale]
+ shape_2 = [batch_size, h * scale, w * scale, 1]
+
+ # Reshape and transpose for periodic shuffling for each channel
+ input_split = tf.split(inputs, channel_target, axis=3)
+ output = tf.concat([phaseShift(x, scale, shape_1, shape_2) for x in input_split], axis=3)
+
+ return output
+
+
+def phaseShift(inputs, scale, shape_1, shape_2):
+    # Handle the case where the batch dimension is None (dynamic batch size)
+ X = tf.reshape(inputs, shape_1)
+ X = tf.transpose(X, [0, 1, 3, 2, 4])
+
+ return tf.reshape(X, shape_2)
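+
+# Shape walkthrough (illustrative, scale=2): pixelShuffler splits a
+# [batch, h, w, 4*k] tensor into k groups of 4 channels, and phaseShift turns
+# each [batch, h, w, 4] group into [batch, 2*h, 2*w, 1] by interleaving the four
+# channels spatially, giving a final [batch, 2*h, 2*w, k] output. In the
+# subpixelconv stages of lib/model.py, 256 channels therefore become 64 channels
+# at twice the spatial resolution.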
+
+
+# The random flip operation used for loading examples
+def random_flip(input, decision):
+ f1 = tf.identity(input)
+ f2 = tf.image.flip_left_right(input)
+ output = tf.cond(tf.less(decision, 0.5), lambda: f2, lambda: f1)
+
+ return output
+
+
+# The operation used to print out the configuration
+def print_configuration_op(FLAGS):
+ print('[Configurations]:')
+    # Accessing a flag forces parsing so that FLAGS.__flags below is populated
+    a = FLAGS.mode
+    # pdb.set_trace()
+ for name, value in FLAGS.__flags.items():
+ if type(value) == float:
+ print('\t%s: %f'%(name, value))
+ elif type(value) == int:
+ print('\t%s: %d'%(name, value))
+ elif type(value) == str:
+ print('\t%s: %s'%(name, value))
+ elif type(value) == bool:
+ print('\t%s: %s'%(name, value))
+ else:
+ print('\t%s: %s' % (name, value))
+
+ print('End of configuration')
+
+
+def compute_psnr(ref, target):
+ ref = tf.cast(ref, tf.float32)
+ target = tf.cast(target, tf.float32)
+ diff = tf.subtract(target, ref)
+ sqr = tf.multiply(diff, diff)
+ err = tf.reduce_sum(sqr)
+ v = tf.shape(diff)[0] * tf.shape(diff)[1] * tf.shape(diff)[2] * tf.shape(diff)[3]
+ mse = tf.div(err, tf.cast(v, tf.float32))
+ c10 = tf.constant(10, tf.float32)
+ c255_2 = tf.multiply(tf.constant(255,tf.float32),tf.constant(255, tf.float32))
+ psnr = tf.multiply(c10, tf.div(tf.log(tf.div(c255_2, mse)), tf.log(c10)))
+
+ return psnr
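+
+# Usage sketch: both arguments are expected to be images in the [0, 255] range
+# (they are cast to float32 internally), as in main.py where the PSNR summary is
+# computed from the converted uint8 tensors:
+#
+#   psnr = compute_psnr(converted_targets, converted_outputs)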
+
+
+# VGG19 component
+def vgg_arg_scope(weight_decay=0.0005):
+ """Defines the VGG arg scope.
+ Args:
+ weight_decay: The l2 regularization coefficient.
+ Returns:
+ An arg_scope.
+ """
+ with slim.arg_scope([slim.conv2d, slim.fully_connected],
+ activation_fn=tf.nn.relu,
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ biases_initializer=tf.zeros_initializer()):
+ with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
+ return arg_sc
+
+
+# VGG19 net
+def vgg_19(inputs,
+ num_classes=1000,
+ is_training=False,
+ dropout_keep_prob=0.5,
+ spatial_squeeze=True,
+ scope='vgg_19',
+ reuse = False,
+ fc_conv_padding='VALID'):
+ """Oxford Net VGG 19-Layers version E Example.
+ Note: All the fully_connected layers have been transformed to conv2d layers.
+ To use in classification mode, resize input to 224x224.
+ Args:
+ inputs: a tensor of size [batch_size, height, width, channels].
+ num_classes: number of predicted classes.
+ is_training: whether or not the model is being trained.
+ dropout_keep_prob: the probability that activations are kept in the dropout
+ layers during training.
+ spatial_squeeze: whether or not should squeeze the spatial dimensions of the
+ outputs. Useful to remove unnecessary dimensions for classification.
+ scope: Optional scope for the variables.
+ fc_conv_padding: the type of padding to use for the fully connected layer
+ that is implemented as a convolutional layer. Use 'SAME' padding if you
+ are applying the network in a fully convolutional manner and want to
+ get a prediction map downsampled by a factor of 32 as an output. Otherwise,
+ the output prediction map will be (input / 32) - 6 in case of 'VALID' padding.
+ Returns:
+ the last op containing the log predictions and end_points dict.
+ """
+ with tf.variable_scope(scope, 'vgg_19', [inputs], reuse=reuse) as sc:
+ end_points_collection = sc.name + '_end_points'
+ # Collect outputs for conv2d, fully_connected and max_pool2d.
+ with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
+ outputs_collections=end_points_collection):
+ net = slim.repeat(inputs, 2, slim.conv2d, 64, 3, scope='conv1', reuse=reuse)
+ net = slim.max_pool2d(net, [2, 2], scope='pool1')
+ net = slim.repeat(net, 2, slim.conv2d, 128, 3, scope='conv2',reuse=reuse)
+ net = slim.max_pool2d(net, [2, 2], scope='pool2')
+ net = slim.repeat(net, 4, slim.conv2d, 256, 3, scope='conv3', reuse=reuse)
+ net = slim.max_pool2d(net, [2, 2], scope='pool3')
+ net = slim.repeat(net, 4, slim.conv2d, 512, 3, scope='conv4',reuse=reuse)
+ net = slim.max_pool2d(net, [2, 2], scope='pool4')
+ net = slim.repeat(net, 4, slim.conv2d, 512, 3, scope='conv5',reuse=reuse)
+ net = slim.max_pool2d(net, [2, 2], scope='pool5')
+ # Use conv2d instead of fully_connected layers.
+ # Convert end_points_collection into a end_point dict.
+ end_points = slim.utils.convert_collection_to_dict(end_points_collection)
+
+ return net, end_points
+vgg_19.default_image_size = 224
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..06e3d3e
--- /dev/null
+++ b/main.py
@@ -0,0 +1,343 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+import tensorflow.contrib.keras as keras
+import os
+from lib.model import data_loader2, generator, SRGAN, test_data_loader, save_images, SRResnet, \
+ SRResnet_dense, generator_denoise
+from lib.ops import *
+import math
+import time
+import numpy as np
+# import ipdb
+
+Flags = tf.app.flags
+
+# The system parameter
+Flags.DEFINE_string('output_dir', None, 'The output directory of the checkpoint')
+Flags.DEFINE_string('summary_dir', None, 'The directory to output the summary')
+Flags.DEFINE_string('mode', 'train', 'The mode of the model: train, test, inference')
+Flags.DEFINE_string('checkpoint', None, 'The checkpoint to be restored')
+Flags.DEFINE_boolean('pre_trained_model', False, 'Whether to use the pretrained model')
+Flags.DEFINE_boolean('is_training', True, 'Whether we are in the training phase')
+Flags.DEFINE_string('vgg_ckpt', './vgg19/vgg_19.ckpt', 'checkpoint file for the vgg19')
+Flags.DEFINE_string('task', None, 'The task: SRGAN, SRResnet')
+Flags.DEFINE_string('generator_type', None, 'The type of the generator: "original", "denseNet".')
+# The data preparing operation
+Flags.DEFINE_integer('batch_size', 128, 'Batch size of input')
+Flags.DEFINE_string('input_dir_LR', None, 'The directory of the low resolution input data')
+Flags.DEFINE_string('input_dir_HR', None, 'The directory of the high resolution input data')
+Flags.DEFINE_integer('train_image_width', 154, 'The width of the training image (low resolution one)')
+Flags.DEFINE_integer('train_image_height', 102, 'The height of the training image (low resolution one)')
+Flags.DEFINE_boolean('flip', True, 'Whether data augmentation (random flipping) is applied')
+Flags.DEFINE_boolean('random_crop', True, 'Whether to perform random crop')
+Flags.DEFINE_integer('crop_size', 24, 'The crop size of the training image')
+Flags.DEFINE_integer('name_queue_capacity', 512, 'The capacity of the filename queue')
+Flags.DEFINE_integer('image_queue_capacity', 1024, 'The capacity of the image queue')
+Flags.DEFINE_integer('queue_thread', 20, 'The threads of the queue')
+# Generator configuration
+Flags.DEFINE_integer('num_resblock', 16, 'The number of residual blocks in the generator')
+# The content loss parameter
+Flags.DEFINE_string('perceptual_mode', 'VGG54', 'The type of feature used in perceptual loss')
+Flags.DEFINE_float('EPS', 1e-12, 'The eps added to prevent nan')
+Flags.DEFINE_float('ratio', 0.001, 'The ratio between content loss and adversarial loss')
+Flags.DEFINE_float('vgg_scaling', 0.0061, 'The scaling factor for the perceptual loss if using vgg perceptual loss')
+# The training parameters
+Flags.DEFINE_float('learning_rate', 0.0001, 'The learning rate for the network')
+Flags.DEFINE_integer('decay_step', 20000, 'The steps needed to decay the learning rate')
+Flags.DEFINE_float('decay_rate', 0.1, 'The decay rate of each decay step')
+Flags.DEFINE_boolean('stair', False, 'Whether to perform staircase decay of the learning rate')
+Flags.DEFINE_float('beta', 0.9, 'The beta1 parameter for the Adam optimizer')
+Flags.DEFINE_integer('max_epoch', None, 'The max epoch for the training')
+Flags.DEFINE_integer('max_iter', 1000000, 'The max iteration of the training')
+Flags.DEFINE_integer('display_freq', 20, 'The display frequency of the training process')
+Flags.DEFINE_integer('summary_freq', 100, 'The frequency of writing summary')
+Flags.DEFINE_integer('save_freq', 10000, 'The frequency of saving images')
+
+
+FLAGS = Flags.FLAGS
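+
+# Example invocations (illustrative only; the directory and checkpoint paths are
+# hypothetical and must be adapted to the local setup):
+#
+#   # SRResnet pre-training with the plain MSE loss
+#   python main.py --mode train --task SRResnet --generator_type original \
+#       --perceptual_mode MSE --output_dir ./experiment_srresnet \
+#       --summary_dir ./experiment_srresnet/log \
+#       --input_dir_LR ./data/train_LR --input_dir_HR ./data/train_HR
+#
+#   # SRGAN training with the VGG54 perceptual loss
+#   python main.py --mode train --task SRGAN --generator_type original \
+#       --perceptual_mode VGG54 --vgg_ckpt ./vgg19/vgg_19.ckpt \
+#       --output_dir ./experiment_srgan --summary_dir ./experiment_srgan/log \
+#       --input_dir_LR ./data/train_LR --input_dir_HR ./data/train_HR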
+
+# Print the configuration of the model
+print_configuration_op(FLAGS)
+
+# Check the output_dir is given
+if FLAGS.output_dir is None:
+    raise ValueError('The output directory is needed')
+
+# Check the output directory to save the checkpoint
+if not os.path.exists(FLAGS.output_dir):
+ os.mkdir(FLAGS.output_dir)
+
+# Check the summary directory to save the event
+if not os.path.exists(FLAGS.summary_dir):
+ os.mkdir(FLAGS.summary_dir)
+
+# The testing mode
+if FLAGS.mode == 'test':
+    # Check the checkpoint
+    if FLAGS.checkpoint is None:
+        raise ValueError('A checkpoint is needed to perform the test.')
+
+    # At test time, no flip or crop is needed
+ if FLAGS.flip == True:
+ FLAGS.flip = False
+
+ if FLAGS.crop_size != None:
+ FLAGS.crop_size = None
+
+ # Declare the test data reader
+ test_data = test_data_loader(FLAGS)
+
+ inputs_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='inputs_raw')
+ targets_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='targets_raw')
+ path_LR = tf.placeholder(tf.string, shape=[], name='path_LR')
+ path_HR = tf.placeholder(tf.string, shape=[], name='path_HR')
+
+ with tf.variable_scope('generator'):
+ if FLAGS.task == 'denoise':
+ gen_output = generator_denoise(inputs_raw, 3, reuse=False, FLAGS=FLAGS)
+ elif FLAGS.task == 'SRGAN':
+ gen_output = generator(inputs_raw, 3, reuse=False, FLAGS=FLAGS)
+
+ print('Finish building the network')
+
+ with tf.name_scope('convert_image'):
+        # Deprocess the images output by the model
+ inputs = deprocessLR(inputs_raw)
+ targets = deprocess(targets_raw)
+ outputs = deprocess(gen_output)
+
+ # Convert back to uint8
+ converted_inputs = tf.image.convert_image_dtype(inputs, dtype=tf.uint8, saturate=True)
+ converted_targets = tf.image.convert_image_dtype(targets, dtype=tf.uint8, saturate=True)
+ converted_outputs = tf.image.convert_image_dtype(outputs, dtype=tf.uint8, saturate=True)
+
+ with tf.name_scope('encode_image'):
+ save_fetch = {
+ "path_LR": path_LR,
+ "path_HR": path_HR,
+ "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name='input_pngs'),
+ "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name='output_pngs'),
+ "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name='target_pngs')
+ }
+
+    # Define the weight initializer (restores only the generator variables)
+ var_list2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
+ weight_initiallizer = tf.train.Saver(var_list2)
+
+ # Define the initialization operation
+ init_op = tf.global_variables_initializer()
+
+ config = tf.ConfigProto()
+ config.gpu_options.allow_growth = True
+ with tf.Session(config=config) as sess:
+ # Load the pretrained model
+ print('Loading weights from the pre-trained model')
+ weight_initiallizer.restore(sess, FLAGS.checkpoint)
+
+ max_iter = len(test_data.inputs)
+ print('Evaluation starts!!')
+ for i in range(max_iter):
+ input_im = np.array([test_data.inputs[i]]).astype(np.float32)
+ target_im = np.array([test_data.targets[i]]).astype(np.float32)
+ path_lr = test_data.paths_LR[i]
+ path_hr = test_data.paths_HR[i]
+ results = sess.run(save_fetch, feed_dict={inputs_raw: input_im, targets_raw: target_im,
+ path_LR: path_lr, path_HR: path_hr})
+ filesets = save_images(results, FLAGS)
+            for f in filesets:
+                print('evaluate image', f['name'])
+
+# The training mode
+elif FLAGS.mode == 'train':
+ # Load data for training and testing
+ # ToDo Add online noise adding and downscaling
+ data = data_loader2(FLAGS)
+ print('Data count = %d' % (data.image_count))
+
+ # Connect to the network
+ if FLAGS.task == 'SRGAN':
+ if FLAGS.generator_type == 'original':
+ Net = SRGAN(data.inputs, data.targets, FLAGS)
+ # elif FLAGS.generator_type == 'denseNet':
+ # Net = SRResnet_dense(data.inputs, data.targets, FLAGS)
+ else:
+ raise ValueError('Unknown generator type')
+ elif FLAGS.task =='SRResnet':
+ if FLAGS.generator_type == 'original':
+ Net = SRResnet(data.inputs, data.targets, FLAGS)
+ elif FLAGS.generator_type == 'denseNet':
+ Net = SRResnet_dense(data.inputs, data.targets, FLAGS)
+ else:
+ raise ValueError('Unknown generator type')
+
+ print('Finish building the network!!')
+
+ # Convert the images output from the network
+ with tf.name_scope('convert_image'):
+        # Deprocess the images output by the model
+ inputs = deprocessLR(data.inputs)
+ targets = deprocess(data.targets)
+ outputs = deprocess(Net.gen_output)
+
+ # Convert back to uint8
+ converted_inputs = tf.image.convert_image_dtype(inputs, dtype=tf.uint8, saturate=True)
+ converted_targets = tf.image.convert_image_dtype(targets, dtype=tf.uint8, saturate=True)
+ converted_outputs = tf.image.convert_image_dtype(outputs, dtype=tf.uint8, saturate=True)
+
+ # Compute PSNR
+ with tf.name_scope("compute_psnr"):
+ psnr = compute_psnr(converted_targets, converted_outputs)
+
+ # Add image summaries
+ with tf.name_scope('inputs_summary'):
+ tf.summary.image('input_summary', converted_inputs)
+
+ with tf.name_scope('targets_summary'):
+ tf.summary.image('target_summary', converted_targets)
+
+ with tf.name_scope('outputs_summary'):
+ tf.summary.image('outputs_summary', converted_outputs)
+
+ # Add scalar summary
+ if FLAGS.task == 'SRGAN':
+ tf.summary.scalar('discriminator_loss', Net.discrim_loss)
+ tf.summary.scalar('adversarial_loss', Net.adversarial_loss)
+ tf.summary.scalar('content_loss', Net.content_loss)
+ tf.summary.scalar('generator_loss', Net.content_loss + FLAGS.ratio*Net.adversarial_loss)
+ tf.summary.scalar('PSNR', psnr)
+ tf.summary.scalar('learning_rate', Net.learning_rate)
+ elif FLAGS.task == 'SRResnet':
+ tf.summary.scalar('content_loss', Net.content_loss)
+ tf.summary.scalar('generator_loss', Net.content_loss)
+ tf.summary.scalar('PSNR', psnr)
+ tf.summary.scalar('learning_rate', Net.learning_rate)
+
+
+    # Define the saver and the weight initializer
+ saver = tf.train.Saver(max_to_keep=10)
+
+ var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+ if FLAGS.task == 'SRGAN':
+ #var_list2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') + \
+ # tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
+ var_list2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+ elif FLAGS.task == 'SRResnet':
+ var_list2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+
+ weight_initiallizer = tf.train.Saver(var_list2)
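+    # weight_initiallizer only covers the generator scope, so SRGAN training can be
+    # warm-started from a generator pre-trained with the MSE loss (see train_SRGAN.sh)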
+
+    # When using the MSE loss, there is no need to restore the VGG network
+ if not FLAGS.perceptual_mode == 'MSE':
+ vgg_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='vgg_19')
+ vgg_restore = tf.train.Saver(vgg_var_list)
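+        # The vgg_19 variables are restored from the pretrained VGG19 checkpoint given
+        # by FLAGS.vgg_ckpt; they are only needed for the perceptual (VGG) loss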
+ # print(vgg_var_list)
+ # init_op = tf.global_variables_initializer()
+ # merged_summary = tf.summary.merge_all()
+ # summary_op = tf.summary.FileWriter(logdir=FLAGS.summary_dir)
+ # Start the session
+ config = tf.ConfigProto()
+ config.gpu_options.allow_growth = True
+ # with tf.Session(config=config) as sess:
+    # Use a Supervisor to coordinate the queue runners and the summary writer
+ sv = tf.train.Supervisor(logdir=FLAGS.summary_dir, save_summaries_secs=0, saver=None)
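+    # saver=None turns off the Supervisor's automatic checkpointing; checkpoints are
+    # written manually below with the saver defined above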
+ with sv.managed_session(config=config) as sess:
+ if (FLAGS.checkpoint is not None) and (FLAGS.pre_trained_model is False):
+ print('Loading model from the checkpoint...')
+ checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
+ saver.restore(sess, checkpoint)
+
+ elif (FLAGS.checkpoint is not None) and (FLAGS.pre_trained_model is True):
+ print('Loading weights from the pre-trained model')
+ weight_initiallizer.restore(sess, FLAGS.checkpoint)
+
+ if not FLAGS.perceptual_mode == 'MSE':
+ vgg_restore.restore(sess, FLAGS.vgg_ckpt)
+ print('VGG19 restored successfully!!')
+
+        # Queue coordination is handled by the Supervisor, so no manual coordinator is needed
+ # coord = tf.train.Coordinator()
+ # threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+        # Perform the training
+ if FLAGS.max_epoch is None:
+ if FLAGS.max_iter is None:
+ raise ValueError('one of max_epoch or max_iter should be provided')
+ else:
+ max_iter = FLAGS.max_iter
+ else:
+ max_iter = FLAGS.max_epoch * data.steps_per_epoch
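+            # e.g. max_epoch=2 with steps_per_epoch=1000 gives max_iter=2000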
+
+ print('Optimization starts!!!')
+ start = time.time()
+ for step in range(max_iter):
+ fetches = {
+ "train": Net.train,
+ "global_step": sv.global_step,
+ }
+
+ if ((step+1) % FLAGS.display_freq) == 0:
+ if FLAGS.task == 'SRGAN':
+ fetches["discrim_loss"] = Net.discrim_loss
+ fetches["adversarial_loss"] = Net.adversarial_loss
+ fetches["content_loss"] = Net.content_loss
+ fetches["PSNR"] = psnr
+ fetches["learning_rate"] = Net.learning_rate
+ fetches["global_step"] = Net.global_step
+ elif FLAGS.task == 'SRResnet':
+ fetches["content_loss"] = Net.content_loss
+ fetches["PSNR"] = psnr
+ fetches["learning_rate"] = Net.learning_rate
+ fetches["global_step"] = Net.global_step
+
+ if ((step+1) % FLAGS.summary_freq) == 0:
+ # fetches["summary"] = merged_summary
+ fetches["summary"] = sv.summary_op
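+                # sv.summary_op is the merged summary op that the Supervisor builds
+                # from all the tf.summary.* calls above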
+
+ results = sess.run(fetches)
+
+ if ((step + 1) % FLAGS.summary_freq) == 0:
+ print('Recording summary!!')
+ # summary_op.add_summary(results['summary'], results['global_step'])
+ sv.summary_writer.add_summary(results['summary'], results['global_step'])
+
+ if ((step + 1) % FLAGS.display_freq) == 0:
+ train_epoch = math.ceil(results["global_step"] / data.steps_per_epoch)
+ train_step = (results["global_step"] - 1) % data.steps_per_epoch + 1
+ rate = (step + 1) * FLAGS.batch_size / (time.time() - start)
+ remaining = (max_iter - step) * FLAGS.batch_size / rate
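+                # rate: images/sec averaged over the run so far; remaining: rough ETA in seconds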
+ print("progress epoch %d step %d image/sec %0.1f remaining %dm" % (train_epoch, train_step, rate, remaining / 60))
+ if FLAGS.task == 'SRGAN':
+ print("global_step", results["global_step"])
+ print("PSNR", results["PSNR"])
+ print("discrim_loss", results["discrim_loss"])
+ print("adversarial_loss", results["adversarial_loss"])
+ print("content_loss", results["content_loss"])
+ print("learning_rate", results['learning_rate'])
+ elif FLAGS.task == 'SRResnet':
+ print("global_step", results["global_step"])
+ print("PSNR", results["PSNR"])
+ print("content_loss", results["content_loss"])
+ print("learning_rate", results['learning_rate'])
+ #print('check', results['check'])
+
+
+            if ((step + 1) % FLAGS.save_freq) == 0:
+                print('Saving the checkpoint')
+ saver.save(sess, os.path.join(FLAGS.output_dir, 'model'), global_step=sv.global_step)
+
+ # coord.request_stop()
+ # coord.join(threads)
+
+        print('Optimization done!')
+
+
+
+
+
+
+
+
diff --git a/test_SRGAN.sh b/test_SRGAN.sh
new file mode 100644
index 0000000..7c39ae7
--- /dev/null
+++ b/test_SRGAN.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
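+# Evaluate the SRGAN (VGG54) model on the Set14 LR/HR pairs using the 200k-step checkpoint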
+CUDA_VISIBLE_DEVICES=0 python main.py \
+ --output_dir ./val_result/SRGAN_VGG54/ \
+ --summary_dir ./val_result/SRGAN_VGG54/log/ \
+ --mode test \
+ --is_training False \
+ --task SRGAN \
+ --batch_size 16 \
+ --input_dir_LR ./data/Set14_LR/ \
+ --input_dir_HR ./data/Set14_HR/ \
+ --num_resblock 16 \
+ --perceptual_mode VGG54 \
+ --pre_trained_model True \
+ --checkpoint ./experiment_SRGAN_VGG54/model-200000
+
diff --git a/tool/__init__.py b/tool/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tool/convertPNG.py b/tool/convertPNG.py
new file mode 100644
index 0000000..e69de29
diff --git a/tool/resizeImage.py b/tool/resizeImage.py
new file mode 100644
index 0000000..07584b4
--- /dev/null
+++ b/tool/resizeImage.py
@@ -0,0 +1,39 @@
+from PIL import Image
+import os
+from multiprocessing import Pool
+
+# Define the input/output directories and the downscaling factor
+input_dir = '../data/RAISE_HR/'
+output_dir = '../data/RAISE_LR/'
+scale = 4.
+
+if not os.path.exists(output_dir):
+ os.mkdir(output_dir)
+
+image_list = os.listdir(input_dir)
+image_list = [os.path.join(input_dir, _) for _ in image_list]
+
+# Worker function executed by the process pool
+def downscale(name):
+ print(name)
+ with Image.open(name) as im:
+ w, h = im.size
+ w_new = int(w / scale)
+ h_new = int(h / scale)
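+        # Image.ANTIALIAS is the Lanczos filter (Image.LANCZOS in newer Pillow releases)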
+ im_new = im.resize((w_new, h_new), Image.ANTIALIAS)
+
+ save_name = os.path.join(output_dir, name.split('/')[-1])
+ im_new.save(save_name)
+
+if __name__ == '__main__':  # guard needed when multiprocessing uses the 'spawn' start method
+    p = Pool(5)
+    p.map(downscale, image_list)
+# for name in image_list:
+# print name
+# with Image.open(name) as im:
+# w, h = im.size
+# w_new = int(w / scale)
+# h_new = int(w / scale)
+# im.resize((w_new, h_new), Image.ANTIALIAS)
+#
+# save_name = os.path.join(output_dir, name.split('/')[-1].split('-0')[0]+'.png')
+# im.save(save_name)
\ No newline at end of file
diff --git a/train_SRGAN.sh b/train_SRGAN.sh
new file mode 100644
index 0000000..6544023
--- /dev/null
+++ b/train_SRGAN.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
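+# Fine-tune SRGAN with the VGG54 perceptual loss, warm-starting the generator from
+# the MSE pre-trained model (experiment_SRGAN_MSE)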
+CUDA_VISIBLE_DEVICES=0 python main.py \
+ --output_dir ./experiment_SRGAN_VGG54/ \
+ --summary_dir ./experiment_SRGAN_VGG54/log/ \
+ --mode train \
+ --is_training True \
+ --task SRGAN \
+ --batch_size 16 \
+ --flip True \
+ --random_crop True \
+ --crop_size 24 \
+ --input_dir_LR ./data/RAISE_LR/ \
+ --input_dir_HR ./data/RAISE_HR/ \
+ --num_resblock 16 \
+ --perceptual_mode VGG54 \
+ --name_queue_capacity 4096 \
+ --image_queue_capacity 4096 \
+ --ratio 0.001 \
+ --learning_rate 0.0001 \
+ --decay_step 100000 \
+ --decay_rate 0.1 \
+ --stair True \
+ --beta 0.9 \
+ --max_iter 200000 \
+ --queue_thread 10 \
+ --vgg_scaling 0.0061 \
+ --pre_trained_model True \
+ --checkpoint ./experiment_SRGAN_MSE/model-500000
+
diff --git a/train_SRResnet.sh b/train_SRResnet.sh
new file mode 100644
index 0000000..30e6d82
--- /dev/null
+++ b/train_SRResnet.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
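+# Train the SRResnet baseline (denseNet generator) from scratch with the plain MSE loss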
+CUDA_VISIBLE_DEVICES=0 python main.py \
+ --output_dir ./experiment_SRResnet_dense/ \
+ --summary_dir ./experiment_SRResnet_dense/log/ \
+ --mode train \
+ --is_training True \
+ --task SRResnet \
+ --generator_type denseNet \
+ --batch_size 16 \
+ --flip True \
+ --random_crop True \
+ --crop_size 24 \
+ --input_dir_LR ./data/RAISE_LR/ \
+ --input_dir_HR ./data/RAISE_HR/ \
+ --num_resblock 16 \
+ --name_queue_capacity 4096 \
+ --image_queue_capacity 4096 \
+ --perceptual_mode MSE \
+ --queue_thread 12 \
+ --ratio 0.001 \
+ --learning_rate 0.0001 \
+ --decay_step 400000 \
+ --decay_rate 0.1 \
+ --stair False \
+ --beta 0.9 \
+ --max_iter 1000000 \
+ --save_freq 20000
+