
Commit 3db088d

Author: wbw520
Commit message: add mae
Parent: 57df8c7

18 files changed, +520 −1823 lines

data/cityscapes.py (+1 −2)
@@ -2,7 +2,6 @@
 import numpy as np
 import torch
 from PIL import Image
-from torch.utils import data


 num_classes = 19
@@ -34,7 +33,7 @@ def make_dataset_cityscapes(args, quality, mode):
     return items


-class CityScapes(data.Dataset):
+class CityScapes(torch.utils.data.Dataset):
     def __init__(self, args, quality, mode, joint_transform=None, standard_transform=None):
         self.imgs = make_dataset_cityscapes(args, quality, mode)
         if len(self.imgs) == 0:
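
The only functional change here is the base-class reference: with `from torch.utils import data` removed, `CityScapes` now subclasses the fully qualified `torch.utils.data.Dataset` through the existing `torch` import. Behavior is unchanged; a quick standalone sanity check (not part of the repo) that the two spellings resolve to the same class:

import torch
from torch.utils import data

# Both names point at the same Dataset class, so the edit is purely cosmetic.
assert data.Dataset is torch.utils.data.Dataset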

data/facade.py (+46 −0)
@@ -0,0 +1,46 @@
+import torch
+from PIL import Image
+import numpy as np
+from utils.base_tools import get_name
+
+
+ignore_label = 255
+
+
+def prepare_facade_data(args):
+    items = get_name(args.root + "/translated_data/images")
+
+
+class Facade(torch.utils.data.Dataset):
+    def __init__(self, args, mode, joint_transform=None, standard_transform=None):
+        self.args = args
+        self.imgs = ""
+        if len(self.imgs) == 0:
+            raise RuntimeError('Found 0 images, please check the data set')
+
+        self.joint_transform = joint_transform
+        self.standard_transform = standard_transform
+        if self.args.use_ignore:
+            self.id_to_trainid = {6: 255, 7: 255, 8: 255, 9: 255}
+
+    def __getitem__(self, index):
+        img_path, mask_path = self.imgs[index]
+        img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
+
+        mask = np.array(mask)
+        mask_copy = mask.copy()
+        if self.args.use_ignore:
+            for k, v in self.id_to_trainid.items():
+                mask_copy[mask == k] = v
+        mask = Image.fromarray(mask_copy.astype(np.uint8))
+
+        if self.joint_transform is not None:
+            img, mask = self.joint_transform(img, mask)
+
+        if self.standard_transform is not None:
+            img = self.standard_transform(img)
+
+        return {"images": img, "masks": torch.from_numpy(np.array(mask, dtype=np.int32)).long()}
+
+    def __len__(self):
+        return len(self.imgs)
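
As committed, `prepare_facade_data` builds `items` but never returns it, and `Facade.__init__` sets `self.imgs = ""`, so the length check always raises `RuntimeError`; the loader reads as a stub to be wired up later. Below is a minimal sketch of how `self.imgs` might be populated with the (image path, mask path) pairs that `__getitem__` unpacks. The `translated_data/masks` directory, the ".png" mask extension, and the use of `glob` instead of the repo's `get_name` helper are assumptions, not taken from this commit:

import glob
import os


def prepare_facade_data(args):
    # Hypothetical pairing: one mask per image, matched by file name.
    # The masks directory layout and ".png" extension are assumptions.
    image_paths = sorted(glob.glob(os.path.join(args.root, "translated_data/images", "*")))
    items = []
    for img_path in image_paths:
        name = os.path.splitext(os.path.basename(img_path))[0]
        items.append((img_path, os.path.join(args.root, "translated_data/masks", name + ".png")))
    return items

# Facade.__init__ could then replace the placeholder with:
#     self.imgs = prepare_facade_data(args)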

data/transforms.py (+0 −1)
@@ -1,7 +1,6 @@
 import math
 import numbers
 import random
-
 from PIL import Image, ImageOps
 import numpy as np
 

mae_demo.py (+107 −0)
@@ -0,0 +1,107 @@
+import sys
+import os
+import requests
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image
+from model import mae_model as models_mae
+
+# define the utils
+
+imagenet_mean = np.array([0.485, 0.456, 0.406])
+imagenet_std = np.array([0.229, 0.224, 0.225])
+
+
+def show_image(image, title=''):
+    # image is [H, W, 3]
+    assert image.shape[2] == 3
+    plt.imshow(torch.clamp((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
+    plt.title(title, fontsize=16)
+    plt.axis('off')
+
+
+def prepare_model(chkpt_dir, arch='mae_vit_large_patch8'):
+    # build model
+    model = models_mae.__dict__[arch](img_size=640)
+    # load model
+    checkpoint = torch.load(chkpt_dir, map_location='cpu')
+    msg = model.load_state_dict(checkpoint['model'], strict=False)
+    print(msg)
+    return model
+
+
+def crop_center(pil_img, crop_width, crop_height):
+    img_width, img_height = pil_img.size
+    return pil_img.crop(((img_width - crop_width) // 2,
+                         (img_height - crop_height) // 2,
+                         (img_width + crop_width) // 2,
+                         (img_height + crop_height) // 2))
+
+
+def run_one_image(img, model):
+    x = torch.tensor(img)
+
+    # make it a batch-like
+    x = x.unsqueeze(dim=0)
+    x = torch.einsum('nhwc->nchw', x)
+
+    # run MAE
+    loss, y, mask = model(x.float(), mask_ratio=0.75)
+    y = model.unpatchify(y)
+    y = torch.einsum('nchw->nhwc', y).detach().cpu()
+
+    # visualize the mask
+    mask = mask.detach()
+    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0] ** 2 * 3)  # (N, H*W, p*p*3)
+    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
+    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
+
+    x = torch.einsum('nchw->nhwc', x)
+
+    # masked image
+    im_masked = x * (1 - mask)
+
+    # MAE reconstruction pasted with visible patches
+    im_paste = x * (1 - mask) + y * mask
+
+    # make the plt figure larger
+    plt.rcParams['figure.figsize'] = [24, 24]
+
+    plt.subplot(1, 4, 1)
+    show_image(x[0], "original")
+
+    plt.subplot(1, 4, 2)
+    show_image(im_masked[0], "masked")
+
+    plt.subplot(1, 4, 3)
+    show_image(y[0], "reconstruction")
+
+    plt.subplot(1, 4, 4)
+    show_image(im_paste[0], "reconstruction + visible")
+
+    plt.show()
+
+
+# load an image
+img = Image.open("/home/wangbowen/DATA/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/test/berlin/berlin_000362_000019_leftImg8bit.png")
+img = crop_center(img, 768, 768)
+img = img.resize((640, 640))
+img = np.array(img) / 255.
+
+# normalize by ImageNet mean and std
+img = img - imagenet_mean
+img = img / imagenet_std
+
+# plt.rcParams['figure.figsize'] = [5, 5]
+# show_image(torch.tensor(img))
+# plt.show()
+
+# This is an MAE model trained with pixels as targets for visualization (ViT-Large, training mask ratio=0.75)
+
+model_mae_gan = prepare_model('save_model/8_640_mae_pre_checkpoint-179.pth', 'mae_vit_large_patch8')
+print('Model loaded.')
+
+# torch.manual_seed(2)
+print('MAE with extra GAN loss:')
+run_one_image(img, model_mae_gan)
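
In `run_one_image`, the model returns one mask value per patch (shape `(N, L)`, 1 = masked), which is expanded to pixel space by repeating it `patch_size**2 * 3` times and passing it through `unpatchify`; that is what lets `x * (1 - mask)` and `x * (1 - mask) + y * mask` operate elementwise on full images. A self-contained sketch of that expansion with a standalone `unpatchify` written in the style of the MAE reference code (square patch grid assumed; this is not the repo's `mae_model` implementation):

import torch


def unpatchify(x, p):
    # x: (N, L, p*p*3) with L = h*w patches and h == w; returns images (N, 3, h*p, w*p)
    n, l, _ = x.shape
    h = w = int(l ** 0.5)
    x = x.reshape(n, h, w, p, p, 3)
    x = torch.einsum('nhwpqc->nchpwq', x)
    return x.reshape(n, 3, h * p, w * p)


N, grid, p = 1, 4, 8                                   # a 4x4 grid of 8x8 patches
mask = torch.randint(0, 2, (N, grid * grid)).float()   # per-patch mask, 1 = masked
mask_px = unpatchify(mask.unsqueeze(-1).repeat(1, 1, p * p * 3), p)
print(mask_px.shape)                                   # torch.Size([1, 3, 32, 32])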
