Skip to content

Commit ee44604

Browse files
committed
extract submodel script and segmentation stage example config
1 parent 1bbc027 commit ee44604

File tree

3 files changed

+82
-0
lines changed

3 files changed

+82
-0
lines changed

configs/sflckr_cond_stage.yaml

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
model:
2+
base_learning_rate: 4.5e-06
3+
target: taming.models.vqgan.VQSegmentationModel
4+
params:
5+
embed_dim: 256
6+
n_embed: 1024
7+
image_key: "segmentation"
8+
n_labels: 182
9+
ddconfig:
10+
double_z: false
11+
z_channels: 256
12+
resolution: 256
13+
in_channels: 182
14+
out_ch: 182
15+
ch: 128
16+
ch_mult:
17+
- 1
18+
- 1
19+
- 2
20+
- 2
21+
- 4
22+
num_res_blocks: 2
23+
attn_resolutions:
24+
- 16
25+
dropout: 0.0
26+
27+
lossconfig:
28+
target: taming.modules.losses.segmentation.BCELossWithQuant
29+
params:
30+
codebook_weight: 1.0
31+
32+
data:
33+
target: cutlit.DataModuleFromConfig
34+
params:
35+
batch_size: 12
36+
train:
37+
target: taming.data.sflckr.Examples # adjust
38+
params:
39+
size: 256
40+
validation:
41+
target: taming.data.sflckr.Examples # adjust
42+
params:
43+
size: 256

scripts/extract_submodel.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import torch
2+
import sys
3+
4+
if __name__ == "__main__":
5+
inpath = sys.argv[1]
6+
outpath = sys.argv[2]
7+
submodel = "cond_stage_model"
8+
if len(sys.argv) > 3:
9+
submodel = sys.argv[3]
10+
11+
print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
12+
13+
sd = torch.load(inpath, map_location="cpu")
14+
new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
15+
for k,v in sd["state_dict"].items()
16+
if k.startswith("cond_stage_model"))}
17+
torch.save(new_sd, outpath)

taming/modules/losses/segmentation.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import torch.nn as nn
2+
import torch.nn.functional as F
3+
4+
5+
class BCELoss(nn.Module):
6+
def forward(self, prediction, target):
7+
loss = F.binary_cross_entropy_with_logits(prediction,target)
8+
return loss, {}
9+
10+
11+
class BCELossWithQuant(nn.Module):
12+
def __init__(self, codebook_weight=1.):
13+
super().__init__()
14+
self.codebook_weight = codebook_weight
15+
16+
def forward(self, qloss, target, prediction, split):
17+
bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18+
loss = bce_loss + self.codebook_weight*qloss
19+
return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20+
"{}/bce_loss".format(split): bce_loss.detach().mean(),
21+
"{}/quant_loss".format(split): qloss.detach().mean()
22+
}

0 commit comments

Comments
 (0)