Commit d929c0e ("d-rin")
1 parent 837ac3f
File tree: 7 files changed, +869 and -2 lines

README.md (74 additions, 0 deletions)
@@ -24,6 +24,47 @@ conda activate taming

## Data Preparation

### ImageNet
The code will try to download (through [Academic
Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
is used. However, since ImageNet is quite large, this requires a lot of disk
space and time. If you already have ImageNet on your disk, you can speed things
up by putting the data into
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
of `train`/`validation`. It should have the following structure:

```
${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
├── n01440764
│   ├── n01440764_10026.JPEG
│   ├── n01440764_10027.JPEG
│   ├── ...
├── n01443537
│   ├── n01443537_10007.JPEG
│   ├── n01443537_10014.JPEG
│   ├── ...
├── ...
```

If you haven't extracted the data, you can also place
`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
extracted into the above structure without downloading it again. Note that this
will only happen if neither a folder
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exists. Remove them
if you want to force running the dataset preparation again.
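The skip-if-prepared check described above can be sketched as follows. The helper name and exact logic are illustrative assumptions, not the repo's implementation:

```python
import os


def needs_preparation(root):
    """Hypothetical sketch of the readiness check: preparation is skipped
    when either an extracted data/ folder or a .ready marker exists."""
    has_data = os.path.isdir(os.path.join(root, "data"))
    has_marker = os.path.exists(os.path.join(root, ".ready"))
    return not (has_data or has_marker)
```

Deleting both the `data/` folder and the `.ready` marker would make such a check report that preparation must run again.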

You will then need to prepare the depth data using
[MiDaS](https://github.com/intel-isl/MiDaS). Create a symlink
`data/imagenet_depth` pointing to a folder with two subfolders `train` and
`val`, each mirroring the structure of the corresponding ImageNet folder
described above and containing a `png` file for each of ImageNet's `JPEG`
files. The `png` encodes `float32` depth values obtained from MiDaS as RGBA
images. We provide the script `TODO` to generate this data.
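Since each `png` stores raw `float32` values reinterpreted as four `uint8` channels, the round trip can be sketched as below. The exact byte layout used by the repo's (yet-unreleased) script is an assumption here:

```python
import numpy as np


def depth_to_rgba(depth):
    """Reinterpret the 4 bytes of each float32 depth value as one RGBA
    pixel. `depth` must be a contiguous (H, W) float32 array."""
    assert depth.dtype == np.float32 and depth.ndim == 2
    return depth.view(np.uint8).reshape(*depth.shape, 4)


def rgba_to_depth(rgba):
    """Inverse: view the 4 uint8 channels as a single float32 again."""
    assert rgba.dtype == np.uint8 and rgba.shape[-1] == 4
    return rgba.view(np.float32).reshape(rgba.shape[:2])
```

Saving `depth_to_rgba(d)` as an RGBA PNG (e.g. via PIL's `Image.fromarray(arr, "RGBA")`) is lossless, so the original `float32` map can be recovered exactly on load.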

### CelebA-HQ
Create a symlink `data/celebahq` pointing to a folder containing the `.npy`
files of CelebA-HQ (instructions to obtain them can be found in the [PGGAN
@@ -47,6 +88,12 @@ Download [2020-11-09T13-31-51_sflckr](TODO) and place it into `logs`. Run
```
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-13T21-41-45_faceshq_transformer/
```

### D-RIN
Download [2020-11-20T12-54-32_drin_transformer](TODO) and place it into `logs`. Run
```
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/
```
## Training models

### FacesHQ

@@ -65,6 +112,33 @@ corresponds to the preconfigured checkpoint path), then run
```
python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,
```

### D-RIN

Train a VQGAN on ImageNet with
```
python main.py --base configs/imagenet_vqgan.yaml -t True --gpus 0,
```

or download a pretrained one from [2020-09-23T17-56-33_imagenet_vqgan](TODO)
and place it under `logs`. If you trained your own, adjust the path in the
config key `model.params.first_stage_config.params.ckpt_path` of
`configs/drin_transformer.yaml`.

Train a VQGAN on ImageNet depth maps with
```
python main.py --base configs/imagenetdepth_vqgan.yaml -t True --gpus 0,
```

or download a pretrained one from [2020-11-03T15-34-24_imagenetdepth_vqgan](TODO)
and place it under `logs`. If you trained your own, adjust the path in the
config key `model.params.cond_stage_config.params.ckpt_path` of
`configs/drin_transformer.yaml`.
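If you prefer to override such a nested key programmatically rather than editing the YAML by hand, a hypothetical dotted-key helper (not part of the repo) might look like this:

```python
def set_key(cfg, dotted, value):
    """Set a dotted key like "model.params.first_stage_config.params.ckpt_path"
    on a nested dict of dicts, creating intermediate levels as needed."""
    *path, leaf = dotted.split(".")
    node = cfg
    for key in path:
        node = node.setdefault(key, {})
    node[leaf] = value
    return cfg
```

With the config loaded as a plain nested dict (e.g. via a YAML loader), `set_key(cfg, "model.params.first_stage_config.params.ckpt_path", "logs/my_vqgan/checkpoints/last.ckpt")` would point the first stage at your own checkpoint; the run-log path here is a made-up example.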

To train the transformer, run
```
python main.py --base configs/drin_transformer.yaml -t True --gpus 0,
```

## Shout-outs
Thanks to everyone who makes their code and models available. In particular,
configs/drin_transformer.yaml (77 additions, 0 deletions)
@@ -0,0 +1,77 @@
```yaml
model:
  base_learning_rate: 4.5e-06
  target: taming.models.cond_transformer.Net2NetTransformer
  params:
    cond_stage_key: depth
    transformer_config:
      target: taming.modules.transformer.mingpt.GPT
      params:
        vocab_size: 1024
        block_size: 512
        n_layer: 24
        n_head: 16
        n_embd: 1024
    first_stage_config:
      target: taming.models.vqgan.VQModel
      params:
        ckpt_path: logs/2020-09-23T17-56-33_imagenet_vqgan/checkpoints/last.ckpt
        embed_dim: 256
        n_embed: 1024
        ddconfig:
          double_z: false
          z_channels: 256
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 16
          dropout: 0.0
        lossconfig:
          target: taming.modules.losses.DummyLoss
    cond_stage_config:
      target: taming.models.vqgan.VQModel
      params:
        ckpt_path: logs/2020-11-03T15-34-24_imagenetdepth_vqgan/checkpoints/last.ckpt
        embed_dim: 256
        n_embed: 1024
        ddconfig:
          double_z: false
          z_channels: 256
          resolution: 256
          in_channels: 1
          out_ch: 1
          ch: 128
          ch_mult:
          - 1
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 16
          dropout: 0.0
        lossconfig:
          target: taming.modules.losses.DummyLoss

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 8
    train:
      target: taming.data.imagenet.RINTrainWithDepth
      params:
        size: 256
    validation:
      target: taming.data.imagenet.RINValidationWithDepth
      params:
        size: 256
```
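The `target`/`params` pairs above follow an instantiate-from-config convention: `target` names a class by its dotted import path and `params` supplies its constructor arguments. A simplified sketch of such a helper (the repo's own version may differ in name and detail):

```python
import importlib


def instantiate_from_config(config):
    """Import the class named by config["target"] and construct it with
    config["params"] as keyword arguments (simplified sketch)."""
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", {}))
```

For example, `instantiate_from_config({"target": "fractions.Fraction", "params": {"numerator": 1, "denominator": 3}})` builds a `Fraction(1, 3)` purely from the config dict.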

configs/imagenet_vqgan.yaml (42 additions, 0 deletions)
@@ -0,0 +1,42 @@
```yaml
model:
  base_learning_rate: 4.5e-6
  target: taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [1, 1, 2, 2, 4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0

    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 250001
        disc_weight: 0.8
        codebook_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 3
    num_workers: 8
    train:
      target: taming.data.imagenet.ImageNetTrain
      params:
        config:
          size: 256
    validation:
      target: taming.data.imagenet.ImageNetValidation
      params:
        config:
          size: 256
```

configs/imagenetdepth_vqgan.yaml (41 additions, 0 deletions)
@@ -0,0 +1,41 @@
```yaml
model:
  base_learning_rate: 4.5e-6
  target: taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    image_key: depth
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 256
      in_channels: 1
      out_ch: 1
      ch: 128
      ch_mult: [1, 1, 2, 2, 4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0

    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: False
        disc_in_channels: 1
        disc_start: 50001
        disc_weight: 0.75
        codebook_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 3
    num_workers: 8
    train:
      target: taming.data.imagenet.ImageNetTrainWithDepth
      params:
        size: 256
    validation:
      target: taming.data.imagenet.ImageNetValidationWithDepth
      params:
        size: 256
```

taming/data/base.py (2 additions, 2 deletions)
@@ -21,11 +21,11 @@ def __getitem__(self, idx):
```diff
 class ImagePaths(Dataset):
-    def __init__(self, paths, size=None, random_crop=False):
+    def __init__(self, paths, size=None, random_crop=False, labels=None):
         self.size = size
         self.random_crop = random_crop

-        self.labels = dict()
+        self.labels = dict() if labels is None else labels
         self.labels["file_path_"] = paths
         self._length = len(paths)
```
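The new `labels` argument lets callers pass pre-computed per-file metadata (e.g. class labels) that the dataset then extends with its own `file_path_` entry. A minimal stand-in illustrating that behavior (not the full `ImagePaths` class, which also handles loading and cropping):

```python
class ImagePathsSketch:
    """Simplified stand-in for ImagePaths, showing only the labels logic."""

    def __init__(self, paths, labels=None):
        # Start from caller-provided labels (or an empty dict) and
        # always record the file paths under "file_path_".
        self.labels = dict() if labels is None else labels
        self.labels["file_path_"] = paths
        self._length = len(paths)

    def __len__(self):
        return self._length
```

Callers can now do e.g. `ImagePathsSketch(paths, labels={"class_label": class_ids})`, and each example's metadata travels alongside its file path.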
