|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +import scipy.misc |
| 4 | +import numpy as np |
| 5 | +import checkGPU |
| 6 | +from model import pix2pix |
| 7 | +import tensorflow as tf |
| 8 | + |
| 9 | +parser = argparse.ArgumentParser(description='') |
| 10 | +parser.add_argument('--dataset_name', dest='dataset_name', default='facades', help='name of the dataset') |
| 11 | +parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch') |
| 12 | +parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch') |
| 13 | +parser.add_argument('--train_size', dest='train_size', type=int, default=1e8, help='# images used to train') |
| 14 | +parser.add_argument('--load_size', dest='load_size', type=int, default=286, help='scale images to this size') |
| 15 | +parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size') |
| 16 | +parser.add_argument('--ngf', dest='ngf', type=int, default=64, help='# of gen filters in first conv layer') |
| 17 | +parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discri filters in first conv layer') |
| 18 | +parser.add_argument('--input_nc', dest='input_nc', type=int, default=3, help='# of input image channels') |
| 19 | +parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels') |
| 20 | +parser.add_argument('--niter', dest='niter', type=int, default=200, help='# of iter at starting learning rate') |
| 21 | +parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam') |
| 22 | +parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam') |
| 23 | +parser.add_argument('--flip', dest='flip', type=bool, default=True, help='if flip the images for data argumentation') |
| 24 | +parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA') |
| 25 | +parser.add_argument('--phase', dest='phase', default='train', help='train, test') |
| 26 | +parser.add_argument('--save_epoch_freq', dest='save_epoch_freq', type=int, default=50, help='save a model every save_epoch_freq epochs (does not overwrite previously saved models)') |
| 27 | +parser.add_argument('--save_latest_freq', dest='save_latest_freq', type=int, default=5000, help='save the latest model every latest_freq sgd iterations (overwrites the previous latest model)') |
| 28 | +parser.add_argument('--print_freq', dest='print_freq', type=int, default=50, help='print the debug information every print_freq iterations') |
| 29 | +parser.add_argument('--continue_train', dest='continue_train', type=bool, default=False, help='if continue training, load the latest model: 1: true, 0: false') |
| 30 | +parser.add_argument('--serial_batches', dest='serial_batches', type=bool, default=False, help='f 1, takes images in order to make batches, otherwise takes them randomly') |
| 31 | +parser.add_argument('--serial_batch_iter', dest='serial_batch_iter', type=bool, default=True, help='iter into serial image list') |
| 32 | +parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='checkpoint', help='models are saved here') |
| 33 | +parser.add_argument('--sample_dir', dest='sample_dir', default='sample', help='sample are saved here') |
| 34 | +parser.add_argument('--test_dir', dest='test_dir', default='test', help='test sample are saved here') |
| 35 | +parser.add_argument('--log_dir', dest='log_dir', default='log', help='log are saved here') |
| 36 | +parser.add_argument('--save_dir', dest='save_dir', default='save', help='save all') |
| 37 | + |
| 38 | +args = parser.parse_args() |
| 39 | + |
| 40 | +gpu_memory_require = 4.0 |
| 41 | + |
| 42 | + |
| 43 | +def main(_): |
| 44 | + args.checkpoint_dir = os.path.join(args.save_dir, args.checkpoint_dir) |
| 45 | + args.sample_dir = os.path.join(args.save_dir, args.sample_dir) |
| 46 | + args.test_dir = os.path.join(args.save_dir, args.test_dir) |
| 47 | + args.log_dir = os.path.join(args.save_dir, args.log_dir) |
| 48 | + if not os.path.exists(args.checkpoint_dir): |
| 49 | + os.makedirs(args.checkpoint_dir) |
| 50 | + if not os.path.exists(args.sample_dir): |
| 51 | + os.makedirs(args.sample_dir) |
| 52 | + if not os.path.exists(args.test_dir): |
| 53 | + os.makedirs(args.test_dir) |
| 54 | + |
| 55 | + checkGPU.auto_queue( |
| 56 | + gpu_memory_require=gpu_memory_require, |
| 57 | + interval=1, |
| 58 | + ) |
| 59 | + config = checkGPU.set_memory_usage( |
| 60 | + usage=gpu_memory_require, |
| 61 | + allow_growth=True |
| 62 | + ) |
| 63 | + |
| 64 | + with tf.Session(config=config) as sess: |
| 65 | + model = pix2pix(sess, image_size=args.fine_size, batch_size=args.batch_size, |
| 66 | + output_size=args.fine_size, dataset_name=args.dataset_name, |
| 67 | + checkpoint_dir=args.checkpoint_dir, sample_dir=args.sample_dir) |
| 68 | + |
| 69 | + if args.phase == 'train': |
| 70 | + model.train(args) |
| 71 | + else: |
| 72 | + model.test(args) |
| 73 | + |
| 74 | +if __name__ == '__main__': |
| 75 | + |
| 76 | + if 0: |
| 77 | + print 'Run on CPU' |
| 78 | + with tf.device("/cpu:0"): |
| 79 | + gpu_memory_require = 0.0 |
| 80 | + tf.app.run() |
| 81 | + |
| 82 | + tf.app.run() |
0 commit comments