-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathutil.py
78 lines (58 loc) · 1.91 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.sampling import *
def init_model(cfgs):
    """Instantiate the model described by ``cfgs`` and load its checkpoint.

    Args:
        cfgs: Run configuration. This function reads ``model_cfg_path`` (path
            to an OmegaConf file that has a ``model`` node), ``load_ckpt_path``
            (checkpoint passed to ``model.init_from_ckpt``), ``type`` and
            ``gpu``.

    Returns:
        The model. When ``cfgs.type == "train"`` it is only switched to
        training mode and is not moved to any device here. For any other
        ``type`` it is moved to ``cuda:<cfgs.gpu>``, switched to eval mode and
        frozen.
    """
    spec = OmegaConf.load(cfgs.model_cfg_path)
    ckpt_path = cfgs.load_ckpt_path
    net = instantiate_from_config(spec.model)
    net.init_from_ckpt(ckpt_path)

    if cfgs.type != "train":
        # Inference path: place the model on the requested GPU before eval/freeze.
        net.to(torch.device("cuda", index=cfgs.gpu))
        net.eval()
        # NOTE(review): presumably a Lightning-style freeze (disables grads) — confirm.
        net.freeze()
        return net

    net.train()
    return net
def init_sampling(cfgs):
    """Build an ``EulerEDMSampler`` configured from ``cfgs``.

    Args:
        cfgs: Run configuration. This function reads ``steps`` (passed as
            ``num_steps``), ``scale`` (only ``scale[0]`` is used, as the
            ``VanillaCFG`` guider's ``scale``) and ``gpu`` (CUDA device index).

    Returns:
        The sampler. It uses a ``LegacyDDPMDiscretization`` discretization and
        verbose output, and runs on ``cuda:<cfgs.gpu>``.
    """
    # Only the first guidance scale is used, even if several are configured.
    guidance = cfgs.scale[0]
    discretization = dict(
        target="sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
    )
    guider = dict(
        target="sgm.modules.diffusionmodules.guiders.VanillaCFG",
        params=dict(scale=guidance),
    )
    # EulerEDMSampler is presumably provided by the wildcard import of
    # sgm.modules.diffusionmodules.sampling at the top of the file.
    return EulerEDMSampler(
        num_steps=cfgs.steps,
        discretization_config=discretization,
        guider_config=guider,
        # NOTE(review): s_churn=0.0 presumably disables the stochastic "churn"
        # step of the EDM Euler sampler — confirm against EulerEDMSampler.
        s_churn=0.0,
        s_tmin=0.0,
        s_tmax=999.0,
        s_noise=1.0,
        verbose=True,
        device=torch.device("cuda", index=cfgs.gpu),
    )
def deep_copy(batch):
    """Return a per-key copy of ``batch`` that can be edited independently.

    Despite the name, this copies only one level:

    * ``torch.Tensor`` values are cloned with ``torch.clone`` (new storage).
    * ``list`` values are shallow-copied, so the list can be changed without
      touching the original, but its elements are shared.
    * ``tuple`` values are immutable and are reused as-is.
    * Every other value is shared by reference.

    Args:
        batch: Mapping of key -> value; only ``.items()`` is used.

    Returns:
        A new ``dict`` with the same keys as ``batch``.
    """
    c_batch = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            c_batch[key] = torch.clone(value)
        elif isinstance(value, list):
            c_batch[key] = value.copy()
        else:
            # Tuples have no ``.copy()`` (calling it raised AttributeError) and,
            # being immutable, need no top-level copy; other values are shared.
            c_batch[key] = value
    return c_batch
def prepare_batch(cfgs, batch):
    """Move tensors to the configured GPU and build the matching unconditional batch.

    ``batch`` is modified in place: each ``torch.Tensor`` value is replaced by
    its copy on ``cuda:<cfgs.gpu>``. The unconditional batch is a
    ``deep_copy`` of the moved batch, with two changes:

    * ``"txt"`` is set to ``batch["ntxt"]`` when that key exists. Otherwise it
      becomes a list of empty strings, one per entry of ``batch["txt"]``.
    * ``"label"``, if present, becomes a list of empty strings of the same
      length.

    Args:
        cfgs: Run configuration; only ``gpu`` (CUDA device index) is read.
        batch: Mapping of key -> value. It must contain ``"txt"`` unless it
            contains ``"ntxt"``.

    Returns:
        ``(batch, batch_uc)``: the mutated input batch and the unconditional copy.
    """
    # Snapshot the items so values can be reassigned while walking the keys.
    for name, value in list(batch.items()):
        if isinstance(value, torch.Tensor):
            batch[name] = value.to(torch.device("cuda", index=cfgs.gpu))

    batch_uc = deep_copy(batch)
    # Prefer an explicit negative prompt; fall back to empty prompts.
    batch_uc["txt"] = (
        batch["ntxt"] if "ntxt" in batch else [""] * len(batch["txt"])
    )
    if "label" in batch:
        batch_uc["label"] = [""] * len(batch["label"])
    return batch, batch_uc