|
| 1 | + |
| 2 | + |
| 3 | +from comet_ml import Experiment |
| 4 | + |
| 5 | +import tensorflow as tf |
| 6 | + |
| 7 | +import argparse |
| 8 | +import subprocess |
| 9 | +import os.path |
| 10 | + |
| 11 | +import logging |
| 12 | +import coloredlogs |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
| 15 | +from utils import * |
| 16 | + |
| 17 | + |
| 18 | +"""parsing and configuration""" |
| 19 | +def parse_args(): |
| 20 | + desc = "Tensorflow implementation of BigGAN" |
| 21 | + parser = argparse.ArgumentParser(description=desc) |
| 22 | + parser.add_argument('--tag' , action="append" , default=[]) |
| 23 | + parser.add_argument('--phase' , type=str , default='train' , help='train or test ?') |
| 24 | + |
| 25 | + parser.add_argument('--train-input-path' , type=str , default='./datasets/imagenet/train*') |
| 26 | + parser.add_argument('--eval-input-path' , type=str , default='./datasets/imagenet/validate*') |
| 27 | + parser.add_argument('--tfr-format' , type=str , default='inception', choices=['inception', 'progan']) |
| 28 | + |
| 29 | + parser.add_argument('--model-dir' , type=str , default='model') |
| 30 | + parser.add_argument('--result-dir' , type=str , default='results') |
| 31 | + |
| 32 | + # SAGAN |
| 33 | + # batch_size = 256 |
| 34 | + # base channel = 64 |
| 35 | + # epoch = 100 (1M iterations) |
| 36 | + # self-attn-res = [64] |
| 37 | + |
| 38 | + parser.add_argument('--img-size' , type=int , default=128 , help='The width and height of the input/output image') |
| 39 | + parser.add_argument('--img-ch' , type=int , default=3 , help='The number of channels in the input/output image') |
| 40 | + |
| 41 | + parser.add_argument('--epochs' , type=int , default=100 , help='The number of training iterations') |
| 42 | + parser.add_argument('--train-steps' , type=int , default=10000 , help='The number of training iterations') |
| 43 | + parser.add_argument('--eval-steps' , type=int , default=100 , help='The number of eval iterations') |
| 44 | + parser.add_argument('--batch-size' , type=int , default=2048 , dest="_batch_size" , help='The size of batch across all GPUs') |
| 45 | + parser.add_argument('--shuffle-buffer' , type=int , default=4000 ) |
| 46 | + |
| 47 | + |
| 48 | + parser.add_argument('--ch' , type=int , default=96 , help='base channel number per layer') |
| 49 | + parser.add_argument('--layers' , type=int , default=5 ) |
| 50 | + |
| 51 | + parser.add_argument('--use-tpu' , action='store_true') |
| 52 | + parser.add_argument('--tpu-name' , action='append' , default=[] ) |
| 53 | + parser.add_argument('--tpu-zone' , type=str, default='us-central1-f') |
| 54 | + parser.add_argument('--num-shards' , type=int , default=8) # A single TPU has 8 shards |
| 55 | + parser.add_argument('--steps-per-loop' , type=int , default=10000) |
| 56 | + |
| 57 | + parser.add_argument('--disable-comet' , action='store_false', dest='use_comet') |
| 58 | + |
| 59 | + parser.add_argument('--self-attn-res' , action='append', default=[] ) |
| 60 | + |
| 61 | + parser.add_argument('--g-lr' , type=float , default=0.00005 , help='learning rate for generator') |
| 62 | + parser.add_argument('--d-lr' , type=float , default=0.0002 , help='learning rate for discriminator') |
| 63 | + |
| 64 | + # if lower batch size |
| 65 | + # g_lr = 0.0001 |
| 66 | + # d_lr = 0.0004 |
| 67 | + |
| 68 | + # if larger batch size |
| 69 | + # g_lr = 0.00005 |
| 70 | + # d_lr = 0.0002 |
| 71 | + |
| 72 | + parser.add_argument('--beta1' , type=float , default=0.0 , help='beta1 for Adam optimizer') |
| 73 | + parser.add_argument('--beta2' , type=float , default=0.9 , help='beta2 for Adam optimizer') |
| 74 | + parser.add_argument('--moving-decay' , type=float , default=0.9999 , help='moving average decay for generator') |
| 75 | + |
| 76 | + parser.add_argument('--z-dim' , type=int , default=128 , help='Dimension of noise vector') |
| 77 | + parser.add_argument('--sn' , type=str2bool , default=True , help='using spectral norm') |
| 78 | + |
| 79 | + parser.add_argument('--gan-type' , type=str , default='hinge' , help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]') |
| 80 | + parser.add_argument('--ld' , type=float , default=10.0 , help='The gradient penalty lambda') |
| 81 | + parser.add_argument('--n-critic' , type=int , default=2 , help='The number of critic') |
| 82 | + |
| 83 | + # IGoodfellow says sould be 50k |
| 84 | + parser.add_argument('--inception-score-num' , type=int , default=512 , help='The number of sample images to use in inception score') |
| 85 | + parser.add_argument('--sample-num' , type=int , default=36 , help='The number of sample images to save') |
| 86 | + parser.add_argument('--test-num' , type=int , default=10 , help='The number of images generated by the test') |
| 87 | + |
| 88 | + parser.add_argument('--verbosity', type=str, default='WARNING') |
| 89 | + |
| 90 | + args = parser.parse_args() |
| 91 | + return check_args(args) |
| 92 | + |
| 93 | + |
| 94 | + |
| 95 | +def check_args(args): |
| 96 | + tf.gfile.MakeDirs(suffixed_folder(args, args.result_dir)) |
| 97 | + tf.gfile.MakeDirs("./temp/") |
| 98 | + |
| 99 | + assert args.epochs >= 1, "number of epochs must be larger than or equal to one" |
| 100 | + assert args._batch_size >= 1, "batch size must be larger than or equal to one" |
| 101 | + assert args.ch >= 8, "--ch cannot be less than 8 otherwise some dimensions of the network will be size 0" |
| 102 | + |
| 103 | + if args.use_tpu: |
| 104 | + assert len(args.tpu_name) > 0, "Please provide at least one --tpu-name" |
| 105 | + |
| 106 | + return args |
| 107 | + |
| 108 | + |
| 109 | + |
| 110 | +def model_dir(args): |
| 111 | + return os.path.join(args.model_dir, *args.tag, model_name(args)) |
| 112 | + |
| 113 | + |
| 114 | + |
| 115 | + |
| 116 | + |
| 117 | +def setup_logging(args): |
| 118 | + |
| 119 | + coloredlogs.install(level='INFO', logger=logger) |
| 120 | + coloredlogs.install(level='INFO', logger=logging.getLogger('main_tpu')) |
| 121 | + coloredlogs.install(level='INFO', logger=logging.getLogger('ops')) |
| 122 | + coloredlogs.install(level='INFO', logger=logging.getLogger('utils')) |
| 123 | + coloredlogs.install(level='INFO', logger=logging.getLogger('BigGAN_128')) |
| 124 | + |
| 125 | + tf.logging.set_verbosity(args.verbosity) |
| 126 | + |
| 127 | + # log = logging.getLogger() |
| 128 | + # log_path = os.path.join(suffixed_folder(args, args.result_dir), 'log.txt') |
| 129 | + # stream = tf.gfile.Open(log_path, 'a') |
| 130 | + # fh = logging.StreamHandler(stream=stream) |
| 131 | + # fh.setLevel(logging.INFO) |
| 132 | + # formatter = logging.Formatter('%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s') |
| 133 | + # fh.setFormatter(formatter) |
| 134 | + # log.addHandler(fh) |
| 135 | + |
| 136 | + logger.info(f"cmd args: {vars(args)}") |
| 137 | + |
0 commit comments