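"""Sample scene images from a layout-conditioned Net2NetTransformer, optionally at
resolutions larger than the 256px training resolution by sliding a fixed-size
window over the latent grid.

Hypothetical invocation (the exact flags for checkpoint, output directory, top_k
and temperature come from scripts.make_samples.get_parser and are assumed here,
not verified):

    python <path/to/this/script.py> --resume logs/<run>/checkpoints/last.ckpt \
        --outdir samples -R 512,512 -C objects_bbox -N 4
"""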
import glob
import os
import sys
from itertools import product
from pathlib import Path
from typing import Literal, List, Optional, Tuple

import numpy as np
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import Tensor
from torchvision.utils import save_image
from tqdm import tqdm

from scripts.make_samples import get_parser, load_model_and_dset
from taming.data.conditional_builder.object_center_points_builder import CoordinatesCenterPointsConditionalBuilder
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.models.cond_transformer import Net2NetTransformer

seed_everything(42424242)
device: Literal['cuda', 'cpu'] = 'cuda'
first_stage_factor = 16
trained_on_res = 256


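# Top-left corner of a coord_window-sized window centred on `coord`, clipped so
# that the window stays inside [0, coord_max).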
def _helper(coord: int, coord_max: int, coord_window: int) -> int:
    assert 0 <= coord < coord_max
    coord_desired_center = (coord_window - 1) // 2
    return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)


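# Relative (x0, y0, w, h) coordinates of the first_stage_factor-sized latent
# window around the position currently being predicted, as passed to the
# conditional builder.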
def get_crop_coordinates(x: int, y: int) -> BoundingBox:
    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
    x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
    y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
    w = first_stage_factor / WIDTH
    h = first_stage_factor / HEIGHT
    return x0, y0, w, h


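# Collect the already-sampled indices inside the current window, in raster order:
# all full rows above `predict_y`, plus the current row up to (but excluding)
# `predict_x`. These serve as the autoregressive context for the next token.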
def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
    x0 = _helper(predict_x, WIDTH, first_stage_factor)
    y0 = _helper(predict_y, HEIGHT, first_stage_factor)
    no_images = z_indices.shape[0]
    cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
    cut_out_2 = z_indices[:, predict_y, x0:predict_x]
    return torch.cat((cut_out_1, cut_out_2), dim=1)


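# Sample `no_samples` images for one layout. If the requested resolution does not
# exceed the training resolution, the whole latent grid is sampled in a single
# pass; otherwise each latent position is sampled one step at a time with a
# sliding first_stage_factor-sized window and crop-adjusted conditioning.
# The plotted layout is appended as an extra image for reference.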
@torch.no_grad()
def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
           conditional_builder: CoordinatesCenterPointsConditionalBuilder, no_samples: int,
           temperature: float, top_k: int) -> Tensor:
    x_max, y_max = desired_z_shape[1], desired_z_shape[0]

    annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]

    recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
    if not recompute_conditional:
        crop_coordinates = get_crop_coordinates(0, 0)
        conditional_indices = conditional_builder.build(annotations, crop_coordinates)
        c_indices = conditional_indices.to(device).repeat(no_samples, 1)
        z_indices = torch.zeros((no_samples, 0), device=device).long()
        output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
                                      sample=True, top_k=top_k)
    else:
        output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
        for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
            crop_coordinates = get_crop_coordinates(predict_x, predict_y)
            z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
            conditional_indices = conditional_builder.build(annotations, crop_coordinates)
            c_indices = conditional_indices.to(device).repeat(no_samples, 1)
            new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
            output_indices[:, predict_y, predict_x] = new_index[:, -1]
    z_shape = (
        no_samples,
        model.first_stage_model.quantize.e_dim,  # codebook embed_dim
        desired_z_shape[0],  # z_height
        desired_z_shape[1]   # z_width
    )
    x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
    x_sample = x_sample.to('cpu')

    plotter = conditional_builder.plot
    figure_size = (x_sample.shape[2], x_sample.shape[3])
    scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
    plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
    return torch.cat((x_sample, plot.unsqueeze(0)))


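# Parse 'height,width', clamp to at least the training resolution, and return the
# latent grid shape together with the matching pixel resolution (rounded to
# multiples of first_stage_factor).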
def get_resolution(resolution_str: str) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    if resolution_str.count(',') != 1:
        raise ValueError("Give resolution as in 'height,width'")
    res_h, res_w = resolution_str.split(',')
    res_h = max(int(res_h), trained_on_res)
    res_w = max(int(res_w), trained_on_res)
    z_h = int(round(res_h/first_stage_factor))
    z_w = int(round(res_w/first_stage_factor))
    return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)


def add_arg_to_parser(parser):
    parser.add_argument(
        "-R",
        "--resolution",
        type=str,
        default='256,256',
        help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
    )
    parser.add_argument(
        "-C",
        "--conditional",
        type=str,
        default='objects_bbox',
        help="objects_bbox or objects_center_points",
    )
    parser.add_argument(
        "-N",
        "--n_samples_per_layout",
        type=int,
        default=4,
        help="how many samples to generate per layout",
    )
    return parser


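# Entry point: resolve checkpoint and config paths from opt.resume (either a
# checkpoint file or a log directory), merge the project configs with CLI
# overrides, then sample every layout in the validation set batch by batch.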
if __name__ == "__main__":
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = add_arg_to_parser(parser)

    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            try:
                idx = len(paths)-paths[::-1].index("logs")+1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir: {logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs+opt.base

    if opt.config:
        if isinstance(opt.config, str):
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    config = OmegaConf.merge(*configs, cli)
    desired_z_shape, desired_resolution = get_resolution(opt.resolution)
    conditional = opt.conditional

    print(ckpt)
    gpu = True
    eval_mode = True
    show_config = False
    if show_config:
        print(OmegaConf.to_container(config))

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    print(f"Global step: {global_step}")

    data_loader = dsets.val_dataloader()
    print(dsets.datasets["validation"].conditional_builders)
    conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]

    outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
    outdir.mkdir(exist_ok=True, parents=True)
    print("Writing samples to", outdir)

    p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
    for batch_no, batch in p_bar_1:
        save_img: Optional[Tensor] = None
        for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
            imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
                          opt.n_samples_per_layout, opt.temperature, opt.top_k)
            save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), nrow=opt.n_samples_per_layout+1)