
Commit 78ef9ca
add databases to inference script
ablattmann committed Jul 25, 2022
1 parent c8a6b66 commit 78ef9ca
Showing 3 changed files with 322 additions and 50 deletions.
71 changes: 71 additions & 0 deletions ldm/modules/encoders/modules.py
@@ -1,6 +1,10 @@
import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
import kornia


from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test

@@ -129,3 +133,70 @@ def forward(self,x):

    def encode(self, x):
        return self(x)


class FrozenCLIPTextEmbedder(nn.Module):
    """
    Uses the CLIP transformer encoder for text.
    """
    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
        super().__init__()
        self.model, _ = clip.load(version, jit=False, device="cpu")
        self.device = device
        self.max_length = max_length
        self.n_repeat = n_repeat
        self.normalize = normalize

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = clip.tokenize(text).to(self.device)
        z = self.model.encode_text(tokens)
        if self.normalize:
            z = z / torch.linalg.norm(z, dim=1, keepdim=True)
        return z

    def encode(self, text):
        z = self(text)
        if z.ndim == 2:
            z = z[:, None, :]
        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
        return z
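A minimal usage sketch for the text embedder (the prompt and shapes are illustrative; assumes the openai `clip` package is installed and that the module is moved to the target device by the caller, e.g. via `.cuda()`):

# hypothetical usage sketch, not part of the commit
embedder = FrozenCLIPTextEmbedder(version='ViT-L/14', device='cuda').cuda()
embedder.freeze()                                # disable gradients on the CLIP weights
z = embedder.encode(['a photograph of a cat'])   # shape (1, n_repeat, 768) for ViT-L/14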


class FrozenClipImageEmbedder(nn.Module):
    """
    Uses the CLIP image encoder.
    """
    def __init__(
            self,
            model,
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=False,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)

        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

    def preprocess(self, x):
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        return self.model.encode_image(self.preprocess(x))
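For reference, a sketch of how the image embedder might be called (model name, batch shape, and CPU device are illustrative; inputs are assumed to already lie in [-1, 1]):

# hypothetical usage sketch, not part of the commit
image_embedder = FrozenClipImageEmbedder(model='ViT-L/14', device='cpu')
x = torch.rand(4, 3, 256, 256) * 2. - 1.   # dummy image batch in [-1, 1]
with torch.no_grad():
    emb = image_embedder(x)                # (4, 768) CLIP image embeddings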

111 changes: 110 additions & 1 deletion ldm/util.py
@@ -83,4 +83,113 @@ def get_obj_from_str(string, reload=False):
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    # run prefetching; each worker pushes its result tagged with its index,
    # followed by a "Done" sentinel so the consumer can count finished workers
    if idx_to_fn:
        res = func(data, worker_id=idx)
    else:
        res = func(data)
    Q.put([idx, res])
    Q.put("Done")


def parallel_data_prefetch(
        func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
    # if target_data_type not in ["ndarray", "list"]:
    #     raise ValueError(
    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
    #     )
    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                'WARNING: "data" argument passed to parallel_data_prefetch is a dict: '
                'Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data that shall be processed in parallel has to be either an np.ndarray "
            f"or an Iterable, but is actually {type(data)}."
        )

    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread
    # split the data into n_proc chunks and build the per-worker arguments
    if target_data_type == "ndarray":
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(np.array_split(data, n_proc))
        ]
    else:
        step = (
            int(len(data) / n_proc + 1)
            if len(data) % n_proc != 0
            else int(len(data) / n_proc)
        )
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(
                [data[i : i + step] for i in range(0, len(data), step)]
            )
        ]
    processes = []
    for i in range(n_proc):
        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
        processes += [p]

    # start processes
    print("Start prefetching...")
    import time

    start = time.time()
    gather_res = [[] for _ in range(n_proc)]
    try:
        for p in processes:
            p.start()

        k = 0
        while k < n_proc:
            # get result
            res = Q.get()
            if res == "Done":
                k += 1
            else:
                gather_res[res[0]] = res[1]

    except Exception as e:
        print("Exception: ", e)
        for p in processes:
            p.terminate()

        raise e
    finally:
        for p in processes:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == 'ndarray':
        if not isinstance(gather_res[0], np.ndarray):
            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)

        # order outputs
        return np.concatenate(gather_res, axis=0)
    elif target_data_type == 'list':
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res
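For context, a sketch of how parallel_data_prefetch might be invoked (the worker function and data are illustrative; with cpu_intensive=True the worker must be a picklable top-level function, and on spawn-based platforms the call should sit under a main guard):

# hypothetical usage sketch, not part of the commit
def square(chunk):
    return [x ** 2 for x in chunk]

if __name__ == '__main__':
    results = parallel_data_prefetch(square, list(range(1000)), n_proc=4,
                                     target_data_type='list')
    print(len(results))  # 1000; per-worker chunk order is preserved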
