Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 140 additions & 55 deletions core/__init__.py

Large diffs are not rendered by default.

44 changes: 25 additions & 19 deletions core/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,49 @@
import json
import os
import subprocess

import yaml
import os
from .bucketeer import Bucketeer

class MultiFilter():

class MultiFilter:
def __init__(self, rules, default=False):
self.rules = rules
self.default = default

def __call__(self, x):
try:
x_json = x['json']
x_json = x["json"]
if isinstance(x_json, bytes):
x_json = json.loads(x_json)
x_json = json.loads(x_json)
validations = []
for k, r in self.rules.items():
if isinstance(k, tuple):
v = r(*[x_json[kv] for kv in k])
else:
v = r(x_json[k])
v = (
r(*[x_json[kv] for kv in k])
if isinstance(k, tuple)
else r(x_json[k])
)
validations.append(v)
return all(validations)
except Exception:
return False

class MultiGetter():

class MultiGetter:
def __init__(self, rules):
self.rules = rules

def __call__(self, x_json):
if isinstance(x_json, bytes):
x_json = json.loads(x_json)
x_json = json.loads(x_json)
outputs = []
for k, r in self.rules.items():
if isinstance(k, tuple):
v = r(*[x_json[kv] for kv in k])
else:
v = r(x_json[k])
v = r(*[x_json[kv] for kv in k]) if isinstance(k, tuple) else r(x_json[k])
outputs.append(v)
if len(outputs) == 1:
outputs = outputs[0]
return outputs


def setup_webdataset_path(paths, cache_path=None):
if cache_path is None or not os.path.exists(cache_path):
tar_paths = []
Expand All @@ -54,15 +55,20 @@ def setup_webdataset_path(paths, cache_path=None):
tar_paths.append(path)
continue
bucket = "/".join(path.split("/")[:3])
result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True)
files = result.stdout.decode('utf-8').split()
result = subprocess.run(
[f"aws s3 ls {path} --recursive | awk '{{print $4}}'"],
stdout=subprocess.PIPE,
shell=True,
check=True,
)
files = result.stdout.decode("utf-8").split()
files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")]
tar_paths += files

with open(cache_path, 'w', encoding='utf-8') as outfile:
with open(cache_path, "w", encoding="utf-8") as outfile:
yaml.dump(tar_paths, outfile, default_flow_style=False)
else:
with open(cache_path, 'r', encoding='utf-8') as file:
with open(cache_path, "r", encoding="utf-8") as file:
tar_paths = yaml.safe_load(file)

tar_paths_str = ",".join([f"{p}" for p in tar_paths])
Expand Down
93 changes: 66 additions & 27 deletions core/data/bucketeer.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,57 @@
import math

import numpy as np
import torch
import torchvision
import numpy as np
from torchtools.transforms import SmartCrop
import math

class Bucketeer():
def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
assert crop_mode in ['center', 'random', 'smart']

class Bucketeer:
def __init__(
self,
dataloader,
density=256 * 256,
factor=8,
ratios=None,
reverse_list=True,
randomize_p=0.3,
randomize_q=0.2,
crop_mode="random",
p_random_ratio=0.0,
interpolate_nearest=False,
):
if ratios is None:
ratios = [1, 1 / 2, 3 / 4, 3 / 5, 4 / 5, 6 / 9, 9 / 16]
assert crop_mode in ["center", "random", "smart"]
self.crop_mode = crop_mode
self.ratios = ratios
if reverse_list:
for r in list(ratios):
if 1/r not in self.ratios:
self.ratios.append(1/r)
self.sizes = [(int(((density/r)**0.5//factor)*factor), int(((density*r)**0.5//factor)*factor)) for r in ratios]
if 1 / r not in self.ratios:
self.ratios.append(1 / r)
self.sizes = [
(
int(((density / r) ** 0.5 // factor) * factor),
int(((density * r) ** 0.5 // factor) * factor),
)
for r in ratios
]
self.batch_size = dataloader.batch_size
self.iterator = iter(dataloader)
self.buckets = {s: [] for s in self.sizes}
self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
self.smartcrop = (
SmartCrop(int(density**0.5), randomize_p, randomize_q)
if self.crop_mode == "smart"
else None
)
self.p_random_ratio = p_random_ratio
self.interpolate_nearest = interpolate_nearest

def get_available_batch(self):
for b in self.buckets:
if len(self.buckets[b]) >= self.batch_size:
batch = self.buckets[b][:self.batch_size]
self.buckets[b] = self.buckets[b][self.batch_size:]
batch = self.buckets[b][: self.batch_size]
self.buckets[b] = self.buckets[b][self.batch_size :]
return batch
return None

Expand All @@ -34,39 +60,52 @@ def get_closest_size(self, x):
best_size_idx = np.random.randint(len(self.ratios))
else:
w, h = x.size(-1), x.size(-2)
best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
best_size_idx = np.argmin([abs(w / h - r) for r in self.ratios])
return self.sizes[best_size_idx]

def get_resize_size(self, orig_size, tgt_size):
if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
resize_size = max(alt_min, min(tgt_size))
if (tgt_size[1] / tgt_size[0] - 1) * (orig_size[1] / orig_size[0] - 1) >= 0:
alt_min = int(math.ceil(max(tgt_size) * min(orig_size) / max(orig_size)))
return max(alt_min, min(tgt_size))
else:
alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
resize_size = max(alt_max, max(tgt_size))
return resize_size
alt_max = int(math.ceil(min(tgt_size) * max(orig_size) / min(orig_size)))
return max(alt_max, max(tgt_size))

def __next__(self):
batch = self.get_available_batch()
while batch is None:
elements = next(self.iterator)
for dct in elements:
img = dct['images']
img = dct["images"]
size = self.get_closest_size(img)
resize_size = self.get_resize_size(img.shape[-2:], size)
if self.interpolate_nearest:
img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
img = torchvision.transforms.functional.resize(
img,
resize_size,
interpolation=torchvision.transforms.InterpolationMode.NEAREST,
)
else:
img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
if self.crop_mode == 'center':
img = torchvision.transforms.functional.resize(
img,
resize_size,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True,
)
if self.crop_mode == "center":
img = torchvision.transforms.functional.center_crop(img, size)
elif self.crop_mode == 'random':
elif self.crop_mode == "random":
img = torchvision.transforms.RandomCrop(size)(img)
elif self.crop_mode == 'smart':
elif self.crop_mode == "smart":
self.smartcrop.output_size = size
img = self.smartcrop(img)
self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
self.buckets[size].append(
{**{"images": img}, **{k: dct[k] for k in dct if k != "images"}}
)
batch = self.get_available_batch()

out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
out = {k: [batch[i][k] for i in range(len(batch))] for k in batch[0]}
return {
k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o
for k, o in out.items()
}
24 changes: 12 additions & 12 deletions core/scripts/cli.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import sys
import argparse
from .. import WarpCore
from .. import templates
import sys

from .. import WarpCore, templates


def template_init(args):
return ''''
return """'


'''.strip()
""".strip()


def init_template(args):
parser = argparse.ArgumentParser(description='WarpCore template init tool')
parser.add_argument('-t', '--template', type=str, default='WarpCore')
parser = argparse.ArgumentParser(description="WarpCore template init tool")
parser.add_argument("-t", "--template", type=str, default="WarpCore")
args = parser.parse_args(args)

if args.template == 'WarpCore':
if args.template == "WarpCore":
template_cls = WarpCore
else:
try:
Expand All @@ -28,14 +28,14 @@ def init_template(args):

def main():
if len(sys.argv) < 2:
print('Usage: core <command>')
print("Usage: core <command>")
sys.exit(1)
if sys.argv[1] == 'init':
if sys.argv[1] == "init":
init_template(sys.argv[2:])
else:
print('Unknown command')
print("Unknown command")
sys.exit(1)


if __name__ == '__main__':
if __name__ == "__main__":
main()
1 change: 0 additions & 1 deletion core/templates/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .diffusion import DiffusionCore
Loading