Skip to content

Commit c2105be

Browse files
now works with imagenet
1 parent 265a182 commit c2105be

13 files changed

+828
-193
lines changed

.DS_Store

0 Bytes
Binary file not shown.

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ nohup*
99
model
1010
samples
1111
results
12-
datasets
12+
datasets*
1313
logs
1414
dataset
1515
checkpoint

LICENSE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2019 Junho Kim (1993.01.12)
3+
Copyright (c) 2019 David Mack
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

README.md

+25-10
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,30 @@
1-
# BigGAN-Tensorflow-TPU
1+
# BigGAN Tensorflow TPU
22

3-
**This is a half-finished TPU conversion of [Junho Kim's](https://github.com/taki0112/BigGAN-Tensorflow) implementation. Only 128x128 supported ATM**
4-
5-
Simple Tensorflow implementation of ["Large Scale GAN Training for High Fidelity Natural Image Synthesis" (BigGAN)](https://arxiv.org/abs/1809.11096)
3+
Simple Tensorflow TPU implementation of ["Large Scale GAN Training for High Fidelity Natural Image Synthesis" (BigGAN)](https://arxiv.org/abs/1809.11096)
64

75
![main](./assets/main.png)
86

9-
## Issue
10-
* **The paper** used `orthogonal initialization`, but `I used random normal initialization.` The reason is, when using the orthogonal initialization, it did not train properly.
11-
* I have applied a hierarchical latent space, but **not** a class embeddedding.
7+
## Implementation notes/issues
8+
9+
- This is a half-finished TPU conversion of [Junho Kim's](https://github.com/taki0112/BigGAN-Tensorflow) implementation. Only 128x128 is currently supported.
10+
- **The paper** used `orthogonal initialization`, but `I used random normal initialization.` The reason is, when using the orthogonal initialization, it did not train properly.
11+
- I have applied a hierarchical latent space, but **not** a class embedding.
1212

1313
## Usage
1414

15-
### train
16-
* pipenv run ./launch_tpu_8.sh
15+
### Building the data
16+
17+
For ImageNet, use [TensorFlow's build scripts](https://github.com/tensorflow/models/blob/master/research/inception/README.md#getting-started) to create TFRecord files of your chosen image size (e.g. 128x128). `--tfr-format inception`
18+
19+
You can also use the data build script from [NVIDIA's Progressive Growing of GANs](https://github.com/tkarras/progressive_growing_of_gans). `--tfr-format progan`
20+
21+
### Training
22+
23+
You can train on a Google TPU by setting the name of your TPU as an env var and running one of the training scripts. For example,
24+
25+
* `TPU_NAME=node-1 pipenv run ./launch_train_tpu_b128.sh`
26+
27+
You need to have your training data stored on a Google cloud bucket.
1728

1829

1930
## Architecture
@@ -28,5 +39,9 @@ Simple Tensorflow implementation of ["Large Scale GAN Training for High Fidelity
2839
### 512x512
2940
<img src = './assets/512.png' width = '600px'>
3041

31-
## Author
42+
## Contributing
43+
44+
You're very welcome to! Submit a PR or [contact the author(s)](https://octavian.ai)
45+
46+
## Authors
3247
Junho Kim, David Mack

args.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
2+
3+
from comet_ml import Experiment
4+
5+
import tensorflow as tf
6+
7+
import argparse
8+
import subprocess
9+
import os.path
10+
11+
import logging
12+
import coloredlogs
13+
logger = logging.getLogger(__name__)
14+
15+
from utils import *
16+
17+
18+
"""parsing and configuration"""
19+
def parse_args(argv=None):
    """Build the BigGAN command-line parser, parse the options and validate them.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what this function always did before the parameter existed.

    Returns:
        The value returned by ``check_args`` for the parsed namespace.
    """
    desc = "Tensorflow implementation of BigGAN"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--tag', action="append", default=[])
    parser.add_argument('--phase', type=str, default='train', help='train or test ?')

    # Input data locations and TFRecord layout
    parser.add_argument('--train-input-path', type=str, default='./datasets/imagenet/train*')
    parser.add_argument('--eval-input-path', type=str, default='./datasets/imagenet/validate*')
    parser.add_argument('--tfr-format', type=str, default='inception', choices=['inception', 'progan'])

    # Output locations
    parser.add_argument('--model-dir', type=str, default='model')
    parser.add_argument('--result-dir', type=str, default='results')

    # SAGAN
    # batch_size = 256
    # base channel = 64
    # epoch = 100 (1M iterations)
    # self-attn-res = [64]

    parser.add_argument('--img-size', type=int, default=128, help='The width and height of the input/output image')
    parser.add_argument('--img-ch', type=int, default=3, help='The number of channels in the input/output image')

    # The help text for --epochs used to be a copy of the --train-steps text.
    parser.add_argument('--epochs', type=int, default=100, help='The number of training epochs')
    parser.add_argument('--train-steps', type=int, default=10000, help='The number of training iterations')
    parser.add_argument('--eval-steps', type=int, default=100, help='The number of eval iterations')
    # Stored as `_batch_size` (note the leading underscore) on the namespace.
    parser.add_argument('--batch-size', type=int, default=2048, dest="_batch_size", help='The size of batch across all GPUs')
    parser.add_argument('--shuffle-buffer', type=int, default=4000)

    parser.add_argument('--ch', type=int, default=96, help='base channel number per layer')
    parser.add_argument('--layers', type=int, default=5)

    parser.add_argument('--use-tpu', action='store_true')
    parser.add_argument('--tpu-name', action='append', default=[])
    parser.add_argument('--tpu-zone', type=str, default='us-central1-f')
    parser.add_argument('--num-shards', type=int, default=8)  # A single TPU has 8 shards
    parser.add_argument('--steps-per-loop', type=int, default=10000)

    # store_false: `use_comet` defaults to True and the flag switches it off.
    parser.add_argument('--disable-comet', action='store_false', dest='use_comet')

    # No `type` is given, so the collected values stay strings.
    parser.add_argument('--self-attn-res', action='append', default=[])

    parser.add_argument('--g-lr', type=float, default=0.00005, help='learning rate for generator')
    parser.add_argument('--d-lr', type=float, default=0.0002, help='learning rate for discriminator')

    # if lower batch size
    # g_lr = 0.0001
    # d_lr = 0.0004

    # if larger batch size
    # g_lr = 0.00005
    # d_lr = 0.0002

    parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for Adam optimizer')
    parser.add_argument('--moving-decay', type=float, default=0.9999, help='moving average decay for generator')

    parser.add_argument('--z-dim', type=int, default=128, help='Dimension of noise vector')
    parser.add_argument('--sn', type=str2bool, default=True, help='using spectral norm')

    parser.add_argument('--gan-type', type=str, default='hinge', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]')
    parser.add_argument('--ld', type=float, default=10.0, help='The gradient penalty lambda')
    parser.add_argument('--n-critic', type=int, default=2, help='The number of critic')

    # IGoodfellow says this should be 50k
    parser.add_argument('--inception-score-num', type=int, default=512, help='The number of sample images to use in inception score')
    parser.add_argument('--sample-num', type=int, default=36, help='The number of sample images to save')
    parser.add_argument('--test-num', type=int, default=10, help='The number of images generated by the test')

    parser.add_argument('--verbosity', type=str, default='WARNING')

    args = parser.parse_args(argv)
    return check_args(args)
92+
93+
94+
95+
def check_args(args):
    """Create the output folders and sanity-check the parsed arguments.

    Args:
        args: The parsed argument namespace; this function reads
            ``result_dir``, ``epochs``, ``_batch_size``, ``ch``, ``use_tpu``
            and ``tpu_name`` from it.

    Returns:
        The same ``args`` object, unmodified.

    Raises:
        AssertionError: If one of the checks below fails.
    """
    # NOTE: the folders are created before any check runs, so they exist
    # even when validation subsequently fails.
    result_folder = suffixed_folder(args, args.result_dir)
    tf.gfile.MakeDirs(result_folder)
    tf.gfile.MakeDirs("./temp/")

    # NOTE(review): `assert` statements are skipped when Python runs with -O.
    assert args.epochs >= 1, "number of epochs must be larger than or equal to one"
    assert args._batch_size >= 1, "batch size must be larger than or equal to one"
    assert args.ch >= 8, "--ch cannot be less than 8 otherwise some dimensions of the network will be size 0"

    # The TPU name is only required when TPU mode was requested.
    if not args.use_tpu:
        return args

    assert len(args.tpu_name) > 0, "Please provide at least one --tpu-name"
    return args
107+
108+
109+
110+
def model_dir(args):
    """Return the model directory path for this run.

    The path is ``args.model_dir``, followed by one path component per entry
    in ``args.tag``, followed by the value of ``model_name(args)``.
    """
    components = [args.model_dir]
    components.extend(args.tag)
    components.append(model_name(args))
    return os.path.join(*components)
112+
113+
114+
115+
116+
117+
def setup_logging(args):
    """Install coloured INFO-level logging and set TensorFlow's verbosity.

    Args:
        args: The parsed argument namespace; ``args.verbosity`` is passed to
            ``tf.logging.set_verbosity`` and the whole namespace is logged.
    """
    # This module's own logger first, then the named loggers.
    coloredlogs.install(level='INFO', logger=logger)
    for logger_name in ('main_tpu', 'ops', 'utils', 'BigGAN_128'):
        coloredlogs.install(level='INFO', logger=logging.getLogger(logger_name))

    tf.logging.set_verbosity(args.verbosity)

    # A handler that appended to log.txt inside the suffixed result folder
    # used to be attached to the root logger here; it is currently disabled.

    logger.info(f"cmd args: {vars(args)}")
137+

0 commit comments

Comments
 (0)