Skip to content

Update the file /taming-transformers/taming/data/utils.py #263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@

This is a fork of the GitHub repository https://github.com/CompVis/taming-transformers. Due to recent changes in PyTorch, the original CompVis/taming-transformers repository requires small updates to remain compatible. I need to use this repo in Chapter 11 of my book Build a Text-to-Image Generator from Scratch with Manning Publications. Rather than asking every reader to manually edit source code, I have created a fork of the repository with these compatibility fixes already applied.

In the file /taming-transformers/taming/data/utils.py, I have changed string_classes to str in line 152. After that, I deleted line 11 of the file (the line that says "from torch._six import string_classes").




# Taming Transformers for High-Resolution Image Synthesis
##### CVPR 2021 (Oral)
![teaser](assets/mountain.jpeg)
Expand Down
338 changes: 169 additions & 169 deletions taming/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,169 +1,169 @@
import collections
import os
import tarfile
import urllib
import zipfile
from pathlib import Path

import numpy as np
import torch
from taming.data.helper_types import Annotation
from torch._six import string_classes
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from tqdm import tqdm


def unpack(path):
    """Extract the archive at ``path`` into the directory that contains it.

    The archive type is chosen from the end of ``path``: ``tar.gz`` is opened
    as a gzip-compressed tar, ``tar`` as an uncompressed tar and ``zip`` as a
    zip file.

    Raises:
        NotImplementedError: if ``path`` ends with none of those suffixes.
    """
    if path.endswith("tar.gz"):
        archive = tarfile.open(path, "r:gz")
    elif path.endswith("tar"):
        archive = tarfile.open(path, "r:")
    elif path.endswith("zip"):
        archive = zipfile.ZipFile(path, "r")
    else:
        extension = os.path.splitext(path)[1]
        raise NotImplementedError("Unknown file extension: {}".format(extension))

    # TarFile and ZipFile are both context managers that offer extractall().
    with archive:
        archive.extractall(path=os.path.dirname(path))


def reporthook(bar):
    """Build a download callback that keeps the progress bar ``bar`` in sync.

    The returned ``hook(b, bsize, tsize)`` sets ``bar.total`` to ``tsize``
    whenever ``tsize`` is given, then advances the bar so that its position
    ``bar.n`` reaches ``b * bsize``.
    """

    def hook(b=1, bsize=1, tsize=None):
        if tsize is not None:
            bar.total = tsize
        transferred = b * bsize
        # update() takes an increment, so pass the distance to the new position.
        bar.update(transferred - bar.n)

    return hook


def get_root(name):
    """Return the path ``data/<name>``, creating the directory if needed."""
    dataset_dir = os.path.join("data/", name)
    os.makedirs(dataset_dir, exist_ok=True)
    return dataset_dir


def is_prepared(root):
    """Return True when a ``.ready`` entry exists directly under ``root``."""
    marker = Path(root) / ".ready"
    return marker.exists()


def mark_prepared(root):
    """Create (or touch) the ``.ready`` marker file directly under ``root``."""
    marker = Path(root) / ".ready"
    marker.touch()


def prompt_download(file_, source, target_dir, content_dir=None):
    """Ask the user to fetch a file by hand and wait until it is present.

    Loops until ``target_dir/file_`` exists or, when ``content_dir`` is given,
    until ``target_dir/content_dir`` exists. Each round prints where to
    download the file from and waits for the user to press Enter.

    Returns:
        The path ``target_dir/file_``, whether or not that exact path exists
        (the loop may also end because the content directory was found).
    """
    targetpath = os.path.join(target_dir, file_)
    has_content_dir = content_dir is not None

    while True:
        if os.path.exists(targetpath):
            break
        if has_content_dir and os.path.exists(
            os.path.join(target_dir, content_dir)
        ):
            break

        message = "Please download '{}' from '{}' to '{}'.".format(
            file_, source, targetpath
        )
        print(message)
        if has_content_dir:
            content_path = os.path.join(target_dir, content_dir)
            print("Or place its content into '{}'.".format(content_path))
        input("Press Enter when done...")

    return targetpath


def download_url(file_, url, target_dir):
    """Download ``url`` to ``target_dir/file_`` while showing a progress bar.

    ``target_dir`` is created if it does not exist. Progress is reported
    through a tqdm bar driven by ``reporthook``.

    Returns:
        The path the file was written to (``target_dir/file_``).
    """
    # A plain ``import urllib`` does not load the ``request`` submodule, so
    # import it explicitly instead of relying on another module having done so.
    import urllib.request

    targetpath = os.path.join(target_dir, file_)
    os.makedirs(target_dir, exist_ok=True)
    with tqdm(
        unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
    ) as bar:
        urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
    return targetpath


def download_urls(urls, target_dir):
    """Download every entry of ``urls`` into ``target_dir``.

    ``urls`` maps a file name to the URL it is fetched from.

    Returns:
        A dict mapping each file name to the path returned by
        ``download_url`` for it.
    """
    return {
        fname: download_url(fname, url, target_dir) for fname, url in urls.items()
    }


def quadratic_crop(x, bbox, alpha=1.0):
    """Cut a square crop centred on ``bbox`` out of the array ``x``.

    bbox is xmin, ymin, xmax, ymax. It is clipped to ``[0, max(height, width)]``
    where height and width are the first two dimensions of ``x``. The crop's
    side length is ``int(alpha * max(box width, box height))``, but at least 2.
    If the centre is closer than one full side length to any border, ``x`` is
    reflect-padded on its first two axes before cropping.

    Returns:
        A new array holding the cropped region.
    """
    im_h, im_w = x.shape[:2]
    box = np.clip(np.array(bbox, dtype=np.float32), 0, max(im_h, im_w))
    x_lo, y_lo, x_hi, y_hi = box[0], box[1], box[2], box[3]
    cx = 0.5 * (x_lo + x_hi)
    cy = 0.5 * (y_lo + y_hi)
    side = max(int(alpha * max(x_hi - x_lo, y_hi - y_lo)), 2)

    # Smallest distance from centre +/- side to the borders; negative means the
    # window would leave the image.
    margins = (cx - side, cy - side, im_w - (cx + side), im_h - (cy + side))
    pad = int(np.ceil(-1 * min(margins)))
    if pad > 0:
        # Only the two leading (spatial) axes are padded.
        pad_widths = [[pad, pad], [pad, pad]] + [[0, 0]] * (len(x.shape) - 2)
        x = np.pad(x, pad_widths, "reflect")
        cx = cx + pad
        cy = cy + pad

    left = int(cx - side / 2)
    top = int(cy - side / 2)
    return np.array(x[top : top + side, left : left + side, ...])


def custom_collate(batch):
    r"""Collate a list of samples into a batch.

    Adapted from PyTorch 1.9.0 ``default_collate`` with one addition: a
    sequence whose first item is an ``Annotation`` is returned unchanged.

    Dispatch is on the type of ``batch[0]``:

    * tensors are stacked along a new first dimension;
    * numpy arrays and numpy scalars are converted to tensors;
    * ``float`` becomes a float64 tensor, ``int`` a tensor;
    * ``str`` and ``bytes`` batches are returned as they are;
    * mappings and namedtuples are collated field by field;
    * other sequences are transposed and collated position by position.

    Raises:
        TypeError: for numpy string/object arrays and for unsupported types.
        RuntimeError: if sequences in the batch differ in length.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return custom_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        # Spelled out instead of ``torch._six.string_classes`` so this block
        # does not depend on that private module.
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: custom_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
    # The length check keeps an empty sequence from raising IndexError on elem[0].
    if (
        isinstance(elem, collections.abc.Sequence)
        and len(elem) > 0
        and isinstance(elem[0], Annotation)
    ):  # added
        return batch  # added
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(sample) == elem_size for sample in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [custom_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
import collections
import os
import tarfile
import urllib
import zipfile
from pathlib import Path
import numpy as np
import torch
from taming.data.helper_types import Annotation
#from torch._six import string_classes
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from tqdm import tqdm
def unpack(path):
    """Extract the archive at ``path`` into the directory that contains it.

    The archive type is chosen from the end of ``path``: ``tar.gz`` is opened
    as a gzip-compressed tar, ``tar`` as an uncompressed tar and ``zip`` as a
    zip file.

    Raises:
        NotImplementedError: if ``path`` ends with none of those suffixes.
    """
    if path.endswith("tar.gz"):
        archive = tarfile.open(path, "r:gz")
    elif path.endswith("tar"):
        archive = tarfile.open(path, "r:")
    elif path.endswith("zip"):
        archive = zipfile.ZipFile(path, "r")
    else:
        extension = os.path.splitext(path)[1]
        raise NotImplementedError("Unknown file extension: {}".format(extension))

    # TarFile and ZipFile are both context managers that offer extractall().
    with archive:
        archive.extractall(path=os.path.dirname(path))
def reporthook(bar):
    """Build a download callback that keeps the progress bar ``bar`` in sync.

    The returned ``hook(b, bsize, tsize)`` sets ``bar.total`` to ``tsize``
    whenever ``tsize`` is given, then advances the bar so that its position
    ``bar.n`` reaches ``b * bsize``.
    """

    def hook(b=1, bsize=1, tsize=None):
        if tsize is not None:
            bar.total = tsize
        transferred = b * bsize
        # update() takes an increment, so pass the distance to the new position.
        bar.update(transferred - bar.n)

    return hook
def get_root(name):
    """Return the path ``data/<name>``, creating the directory if needed."""
    dataset_dir = os.path.join("data/", name)
    os.makedirs(dataset_dir, exist_ok=True)
    return dataset_dir
def is_prepared(root):
    """Return True when a ``.ready`` entry exists directly under ``root``."""
    marker = Path(root) / ".ready"
    return marker.exists()
def mark_prepared(root):
    """Create (or touch) the ``.ready`` marker file directly under ``root``."""
    marker = Path(root) / ".ready"
    marker.touch()
def prompt_download(file_, source, target_dir, content_dir=None):
    """Ask the user to fetch a file by hand and wait until it is present.

    Loops until ``target_dir/file_`` exists or, when ``content_dir`` is given,
    until ``target_dir/content_dir`` exists. Each round prints where to
    download the file from and waits for the user to press Enter.

    Returns:
        The path ``target_dir/file_``, whether or not that exact path exists
        (the loop may also end because the content directory was found).
    """
    targetpath = os.path.join(target_dir, file_)
    has_content_dir = content_dir is not None

    while True:
        if os.path.exists(targetpath):
            break
        if has_content_dir and os.path.exists(
            os.path.join(target_dir, content_dir)
        ):
            break

        message = "Please download '{}' from '{}' to '{}'.".format(
            file_, source, targetpath
        )
        print(message)
        if has_content_dir:
            content_path = os.path.join(target_dir, content_dir)
            print("Or place its content into '{}'.".format(content_path))
        input("Press Enter when done...")

    return targetpath
def download_url(file_, url, target_dir):
    """Download ``url`` to ``target_dir/file_`` while showing a progress bar.

    ``target_dir`` is created if it does not exist. Progress is reported
    through a tqdm bar driven by ``reporthook``.

    Returns:
        The path the file was written to (``target_dir/file_``).
    """
    # A plain ``import urllib`` does not load the ``request`` submodule, so
    # import it explicitly instead of relying on another module having done so.
    import urllib.request

    targetpath = os.path.join(target_dir, file_)
    os.makedirs(target_dir, exist_ok=True)
    with tqdm(
        unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
    ) as bar:
        urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
    return targetpath
def download_urls(urls, target_dir):
    """Download every entry of ``urls`` into ``target_dir``.

    ``urls`` maps a file name to the URL it is fetched from.

    Returns:
        A dict mapping each file name to the path returned by
        ``download_url`` for it.
    """
    return {
        fname: download_url(fname, url, target_dir) for fname, url in urls.items()
    }
def quadratic_crop(x, bbox, alpha=1.0):
    """Cut a square crop centred on ``bbox`` out of the array ``x``.

    bbox is xmin, ymin, xmax, ymax. It is clipped to ``[0, max(height, width)]``
    where height and width are the first two dimensions of ``x``. The crop's
    side length is ``int(alpha * max(box width, box height))``, but at least 2.
    If the centre is closer than one full side length to any border, ``x`` is
    reflect-padded on its first two axes before cropping.

    Returns:
        A new array holding the cropped region.
    """
    im_h, im_w = x.shape[:2]
    box = np.clip(np.array(bbox, dtype=np.float32), 0, max(im_h, im_w))
    x_lo, y_lo, x_hi, y_hi = box[0], box[1], box[2], box[3]
    cx = 0.5 * (x_lo + x_hi)
    cy = 0.5 * (y_lo + y_hi)
    side = max(int(alpha * max(x_hi - x_lo, y_hi - y_lo)), 2)

    # Smallest distance from centre +/- side to the borders; negative means the
    # window would leave the image.
    margins = (cx - side, cy - side, im_w - (cx + side), im_h - (cy + side))
    pad = int(np.ceil(-1 * min(margins)))
    if pad > 0:
        # Only the two leading (spatial) axes are padded.
        pad_widths = [[pad, pad], [pad, pad]] + [[0, 0]] * (len(x.shape) - 2)
        x = np.pad(x, pad_widths, "reflect")
        cx = cx + pad
        cy = cy + pad

    left = int(cx - side / 2)
    top = int(cy - side / 2)
    return np.array(x[top : top + side, left : left + side, ...])
def custom_collate(batch):
    r"""Collate a list of samples into a batch.

    Adapted from PyTorch 1.9.0 ``default_collate`` with one addition: a
    sequence whose first item is an ``Annotation`` is returned unchanged.

    Dispatch is on the type of ``batch[0]``:

    * tensors are stacked along a new first dimension;
    * numpy arrays and numpy scalars are converted to tensors;
    * ``float`` becomes a float64 tensor, ``int`` a tensor;
    * ``str`` and ``bytes`` batches are returned as they are;
    * mappings and namedtuples are collated field by field;
    * other sequences are transposed and collated position by position.

    Raises:
        TypeError: for numpy string/object arrays and for unsupported types.
        RuntimeError: if sequences in the batch differ in length.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return custom_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        # ``bytes`` is kept alongside ``str``: otherwise a batch of byte
        # strings would reach the Sequence branch and become integer tensors.
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: custom_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
    # The length check keeps an empty sequence from raising IndexError on elem[0].
    if (
        isinstance(elem, collections.abc.Sequence)
        and len(elem) > 0
        and isinstance(elem[0], Annotation)
    ):  # added
        return batch  # added
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(sample) == elem_size for sample in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [custom_collate(samples) for samples in transposed]
    raise TypeError(default_collate_err_msg_format.format(elem_type))