Commit 2908a53

Merge pull request #97 from CompVis/scene-images-coco
Added scene image generation for COCO 🌆
2 parents 141eb74 + 6194bd1

File tree

221 files changed: +3794 -10 lines

Note: large commits hide some content by default, so only a subset of the 221 changed files is shown below.

assets/coco_scene_images_training.svg  (+2,574)

configs/coco_cond_stage.yaml  (+2 -2)

@@ -30,7 +30,7 @@ model:
         codebook_weight: 1.0

 data:
-  target: cutlit.DataModuleFromConfig
+  target: main.DataModuleFromConfig
   params:
     batch_size: 12
     train:
@@ -41,7 +41,7 @@ data:
         onehot_segmentation: true
         use_stuffthing: true
     validation:
-      target: taming.data.coco.CocoImagesAndCaptionsTrain
+      target: taming.data.coco.CocoImagesAndCaptionsValidation
       params:
         size: 256
         crop_size: 256

New config file (path not shown in this view)  (+81)

@@ -0,0 +1,81 @@
+model:
+  base_learning_rate: 4.5e-06
+  target: taming.models.cond_transformer.Net2NetTransformer
+  params:
+    cond_stage_key: objects_bbox
+    transformer_config:
+      target: taming.modules.transformer.mingpt.GPT
+      params:
+        vocab_size: 8192
+        block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
+        n_layer: 40
+        n_head: 16
+        n_embd: 1408
+        embd_pdrop: 0.1
+        resid_pdrop: 0.1
+        attn_pdrop: 0.1
+    first_stage_config:
+      target: taming.models.vqgan.VQModel
+      params:
+        ckpt_path: /path/to/coco_epoch117.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/
+        embed_dim: 256
+        n_embed: 8192
+        ddconfig:
+          double_z: false
+          z_channels: 256
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 1
+          - 2
+          - 2
+          - 4
+          num_res_blocks: 2
+          attn_resolutions:
+          - 16
+          dropout: 0.0
+        lossconfig:
+          target: taming.modules.losses.DummyLoss
+    cond_stage_config:
+      target: taming.models.dummy_cond_stage.DummyCondStage
+      params:
+        conditional_key: objects_bbox
+
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 6
+    num_workers: 12
+    train:
+      target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
+      params:
+        data_path: data/coco_annotations_100
+        split: train
+        keys: [image, objects_bbox, file_name]
+        no_tokens: 8192
+        target_image_size: 256
+        min_object_area: 0.00001
+        min_objects_per_image: 2
+        max_objects_per_image: 30
+        crop_method: random-1d
+        random_flip: true
+        use_group_parameter: true
+        encode_crop: true
+    validation:
+      target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
+      params:
+        data_path: data/coco_annotations_100
+        split: validation
+        keys: [image, objects_bbox, file_name]
+        no_tokens: 8192
+        target_image_size: 256
+        min_object_area: 0.00001
+        min_objects_per_image: 2
+        max_objects_per_image: 30
+        crop_method: center
+        random_flip: false
+        use_group_parameter: true
+        encode_crop: true
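
For orientation, the block_size of 348 in the transformer config above decomposes exactly as its inline comment states: 256 latent positions from the 16x16 VQGAN grid (256-pixel crops at a downsampling factor of 16) plus 92 conditioning positions. A minimal sanity check of that arithmetic, under the assumption (not spelled out in this diff) that the bounding-box builder spends 3 tokens per object (category plus two box corners) and 2 further tokens for the encoded crop:

# Hedged sanity check of block_size = 348 from the config above.
# Assumption: 3 tokens per object (category, top-left, bottom-right) plus
# 2 tokens for the encoded crop; only the totals are taken from the config.
latent_tokens = (256 // 16) ** 2            # 16x16 VQGAN latent grid -> 256
conditioning_tokens = 30 * 3 + 2            # max_objects_per_image * 3 + crop -> 92
assert latent_tokens + conditioning_tokens == 348  # block_size in the config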

data/coco_annotations_100/annotations/instances_train2017.json  (+1; large diff not rendered)
data/coco_annotations_100/annotations/instances_val2017.json  (+1; large diff not rendered)
data/coco_annotations_100/annotations/stuff_train2017.json  (+1; large diff not rendered)
data/coco_annotations_100/annotations/stuff_val2017.json  (+1; large diff not rendered)

environment.yaml  (+1)

@@ -20,5 +20,6 @@ dependencies:
     - test-tube>=0.7.5
     - streamlit>=0.73.1
     - einops==0.3.0
+    - more-itertools>=8.0.0
     - transformers==4.3.1
     - -e .

main.py  (+7 -4)

@@ -11,6 +11,9 @@
 from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
 from pytorch_lightning.utilities.distributed import rank_zero_only

+from taming.data.utils import custom_collate
+
+
 def get_obj_from_str(string, reload=False):
     module, cls = string.rsplit(".", 1)
     if reload:
@@ -160,16 +163,16 @@ def setup(self, stage=None):

     def _train_dataloader(self):
         return DataLoader(self.datasets["train"], batch_size=self.batch_size,
-                          num_workers=self.num_workers, shuffle=True)
+                          num_workers=self.num_workers, shuffle=True, collate_fn=custom_collate)

     def _val_dataloader(self):
         return DataLoader(self.datasets["validation"],
                           batch_size=self.batch_size,
-                          num_workers=self.num_workers)
+                          num_workers=self.num_workers, collate_fn=custom_collate)

     def _test_dataloader(self):
         return DataLoader(self.datasets["test"], batch_size=self.batch_size,
-                          num_workers=self.num_workers)
+                          num_workers=self.num_workers, collate_fn=custom_collate)


 class SetupCallback(Callback):
@@ -278,7 +281,7 @@ def log_img(self, pl_module, batch, batch_idx, split="train"):
         pl_module.eval()

         with torch.no_grad():
-            images = pl_module.log_images(batch, split=split)
+            images = pl_module.log_images(batch, split=split, pl_module=pl_module)

         for k in images:
             N = min(images[k].shape[0], self.max_images)
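
The data-loader changes above route every split through custom_collate from taming.data.utils; its implementation is not shown in this view. As a rough illustration of why the stock collate is not enough here, below is a minimal sketch (an assumption, not the repository's actual code) of a collate function that keeps variable-length per-image annotation lists as plain Python lists while deferring everything else to default_collate:

# Illustrative sketch only -- not the repository's custom_collate.
# Batches from the annotated-objects datasets mix image tensors with
# variable-length lists of Annotation tuples; default_collate would try to
# stack those lists and fail, so they are passed through untouched.
from torch.utils.data.dataloader import default_collate

def custom_collate_sketch(batch):
    elem = batch[0]
    if isinstance(elem, dict):
        # collate each field separately so annotation lists stay lists
        return {key: custom_collate_sketch([d[key] for d in batch]) for key in elem}
    if isinstance(elem, (list, tuple)):
        # e.g. one list of Annotation tuples per image: return as list of lists
        return list(batch)
    return default_collate(batch)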

scripts/make_scene_samples.py  (+198)

@@ -0,0 +1,198 @@
+import glob
+import os
+import sys
+from itertools import product
+from pathlib import Path
+from typing import Literal, List, Optional, Tuple
+
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+from pytorch_lightning import seed_everything
+from torch import Tensor
+from torchvision.utils import save_image
+from tqdm import tqdm
+
+from scripts.make_samples import get_parser, load_model_and_dset
+from taming.data.conditional_builder.object_center_points_builder import CoordinatesCenterPointsConditionalBuilder
+from taming.data.helper_types import BoundingBox, Annotation
+from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from taming.models.cond_transformer import Net2NetTransformer
+
+seed_everything(42424242)
+device: Literal['cuda', 'cpu'] = 'cuda'
+first_stage_factor = 16
+trained_on_res = 256
+
+
+def _helper(coord: int, coord_max: int, coord_window: int) -> (int, int):
+    assert 0 <= coord < coord_max
+    coord_desired_center = (coord_window - 1) // 2
+    return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)
+
+
+def get_crop_coordinates(x: int, y: int) -> BoundingBox:
+    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
+    x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
+    y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
+    w = first_stage_factor / WIDTH
+    h = first_stage_factor / HEIGHT
+    return x0, y0, w, h
+
+
+def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
+    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
+    x0 = _helper(predict_x, WIDTH, first_stage_factor)
+    y0 = _helper(predict_y, HEIGHT, first_stage_factor)
+    no_images = z_indices.shape[0]
+    cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
+    cut_out_2 = z_indices[:, predict_y, x0:predict_x]
+    return torch.cat((cut_out_1, cut_out_2), dim=1)
+
+
+@torch.no_grad()
+def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
+           conditional_builder: CoordinatesCenterPointsConditionalBuilder, no_samples: int,
+           temperature: float, top_k: int) -> Tensor:
+    x_max, y_max = desired_z_shape[1], desired_z_shape[0]
+
+    annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]
+
+    recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
+    if not recompute_conditional:
+        crop_coordinates = get_crop_coordinates(0, 0)
+        conditional_indices = conditional_builder.build(annotations, crop_coordinates)
+        c_indices = conditional_indices.to(device).repeat(no_samples, 1)
+        z_indices = torch.zeros((no_samples, 0), device=device).long()
+        output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
+                                      sample=True, top_k=top_k)
+    else:
+        output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
+        for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
+            crop_coordinates = get_crop_coordinates(predict_x, predict_y)
+            z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
+            conditional_indices = conditional_builder.build(annotations, crop_coordinates)
+            c_indices = conditional_indices.to(device).repeat(no_samples, 1)
+            new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
+            output_indices[:, predict_y, predict_x] = new_index[:, -1]
+    z_shape = (
+        no_samples,
+        model.first_stage_model.quantize.e_dim,  # codebook embed_dim
+        desired_z_shape[0],  # z_height
+        desired_z_shape[1]  # z_width
+    )
+    x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
+    x_sample = x_sample.to('cpu')
+
+    plotter = conditional_builder.plot
+    figure_size = (x_sample.shape[2], x_sample.shape[3])
+    scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
+    plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
+    return torch.cat((x_sample, plot.unsqueeze(0)))
+
+
+def get_resolution(resolution_str: str) -> (Tuple[int, int], Tuple[int, int]):
+    if not resolution_str.count(',') == 1:
+        raise ValueError("Give resolution as in 'height,width'")
+    res_h, res_w = resolution_str.split(',')
+    res_h = max(int(res_h), trained_on_res)
+    res_w = max(int(res_w), trained_on_res)
+    z_h = int(round(res_h/first_stage_factor))
+    z_w = int(round(res_w/first_stage_factor))
+    return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)
+
+
+def add_arg_to_parser(parser):
+    parser.add_argument(
+        "-R",
+        "--resolution",
+        type=str,
+        default='256,256',
+        help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
+    )
+    parser.add_argument(
+        "-C",
+        "--conditional",
+        type=str,
+        default='objects_bbox',
+        help=f"objects_bbox or objects_center_points",
+    )
+    parser.add_argument(
+        "-N",
+        "--n_samples_per_layout",
+        type=int,
+        default=4,
+        help=f"how many samples to generate per layout",
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    sys.path.append(os.getcwd())
+
+    parser = get_parser()
+    parser = add_arg_to_parser(parser)
+
+    opt, unknown = parser.parse_known_args()
+
+    ckpt = None
+    if opt.resume:
+        if not os.path.exists(opt.resume):
+            raise ValueError("Cannot find {}".format(opt.resume))
+        if os.path.isfile(opt.resume):
+            paths = opt.resume.split("/")
+            try:
+                idx = len(paths)-paths[::-1].index("logs")+1
+            except ValueError:
+                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
+            logdir = "/".join(paths[:idx])
+            ckpt = opt.resume
+        else:
+            assert os.path.isdir(opt.resume), opt.resume
+            logdir = opt.resume.rstrip("/")
+            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
+        print(f"logdir:{logdir}")
+        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
+        opt.base = base_configs+opt.base
+
+    if opt.config:
+        if type(opt.config) == str:
+            opt.base = [opt.config]
+        else:
+            opt.base = [opt.base[-1]]
+
+    configs = [OmegaConf.load(cfg) for cfg in opt.base]
+    cli = OmegaConf.from_dotlist(unknown)
+    if opt.ignore_base_data:
+        for config in configs:
+            if hasattr(config, "data"):
+                del config["data"]
+    config = OmegaConf.merge(*configs, cli)
+    desired_z_shape, desired_resolution = get_resolution(opt.resolution)
+    conditional = opt.conditional
+
+    print(ckpt)
+    gpu = True
+    eval_mode = True
+    show_config = False
+    if show_config:
+        print(OmegaConf.to_container(config))
+
+    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
+    print(f"Global step: {global_step}")
+
+    data_loader = dsets.val_dataloader()
+    print(dsets.datasets["validation"].conditional_builders)
+    conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]
+
+    outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
+    outdir.mkdir(exist_ok=True, parents=True)
+    print("Writing samples to ", outdir)
+
+    p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
+    for batch_no, batch in p_bar_1:
+        save_img: Optional[Tensor] = None
+        for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
+            imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
+                          opt.n_samples_per_layout, opt.temperature, opt.top_k)
+            save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), n_row=opt.n_samples_per_layout+1)
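
The new script can sample scenes larger than the 256-pixel training resolution: when the requested resolution exceeds it, the sliding-window branch of sample() re-samples the latent grid token by token, each time rebuilding the conditioning for a 16x16 window centred on the position being predicted (via get_crop_coordinates and get_z_indices_crop_out). The requested resolution is first snapped to the latent grid; a worked example of the get_resolution() mapping, using only the constants defined in the script:

# Worked example of the get_resolution() mapping above; constants from the script.
first_stage_factor, trained_on_res = 16, 256

res_h, res_w = 512, 320                       # e.g. passed as "-R 512,320"
res_h, res_w = max(res_h, trained_on_res), max(res_w, trained_on_res)
desired_z_shape = (round(res_h / first_stage_factor), round(res_w / first_stage_factor))
desired_resolution = (desired_z_shape[0] * first_stage_factor,
                      desired_z_shape[1] * first_stage_factor)
assert desired_z_shape == (32, 20) and desired_resolution == (512, 320)
# Anything below the training resolution is clamped up, so "-R 100,100"
# yields a z-shape of (16, 16) and a plain 256x256 sample.

A typical invocation combines the new -R / -C / -N flags with the options the script inherits from scripts/make_samples.py's get_parser, such as the resume, outdir, top_k and temperature options read in the __main__ block above.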
