import sys
import os
import requests
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from model import mae_model as models_mae

# define the utils

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

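# show_image() below inverts this normalization (image * std + mean) and
# rescales to [0, 255] for display.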
def show_image(image, title=''):
    # image is [H, W, 3]
    assert image.shape[2] == 3
    plt.imshow(torch.clamp((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(chkpt_dir, arch='mae_vit_large_patch8'):
    # build model
    model = models_mae.__dict__[arch](img_size=640)
    # load checkpoint weights
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model
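# Note: strict=False lets load_state_dict() skip keys that do not match the
# checkpoint; the printed message lists any missing/unexpected keys.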

def crop_center(pil_img, crop_width, crop_height):
    img_width, img_height = pil_img.size
    return pil_img.crop(((img_width - crop_width) // 2,
                         (img_height - crop_height) // 2,
                         (img_width + crop_width) // 2,
                         (img_height + crop_height) // 2))
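# e.g. on a 2048x1024 Cityscapes frame, crop_center(img, 768, 768) keeps the
# centered box (left, top, right, bottom) = (640, 128, 1408, 896).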

def run_one_image(img, model):
    x = torch.tensor(img)

    # make it batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    # run MAE
    loss, y, mask = model(x.float(), mask_ratio=0.75)
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # visualize the mask
    mask = mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0] ** 2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    x = torch.einsum('nchw->nhwc', x)

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    plt.subplot(1, 4, 1)
    show_image(x[0], "original")

    plt.subplot(1, 4, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 4, 3)
    show_image(y[0], "reconstruction")

    plt.subplot(1, 4, 4)
    show_image(im_paste[0], "reconstruction + visible")

    plt.show()
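# With the patch-8, 640x640 config from prepare_model(), mask is (N, 6400):
# one 0/1 entry per 8x8 patch, expanded to 8 * 8 * 3 = 192 pixel values
# before unpatchify() maps it back to image shape.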

# load an image
img = Image.open("/home/wangbowen/DATA/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/test/berlin/berlin_000362_000019_leftImg8bit.png")
img = crop_center(img, 768, 768)
img = img.resize((640, 640))
img = np.array(img) / 255.

# normalize by ImageNet mean and std
img = img - imagenet_mean
img = img / imagenet_std

# plt.rcParams['figure.figsize'] = [5, 5]
# show_image(torch.tensor(img))
# plt.show()
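# A hedged alternative if the hard-coded path above is unavailable: fetch any
# RGB test image over HTTP via the already-imported requests (placeholder URL),
# then rerun the crop/resize/normalize steps above:
# img = Image.open(requests.get('https://example.com/test.png', stream=True).raw).convert('RGB')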

# This is an MAE model trained with pixels as targets for visualization (ViT-Large, training mask ratio=0.75)

model_mae_gan = prepare_model('save_model/8_640_mae_pre_checkpoint-179.pth', 'mae_vit_large_patch8')
print('Model loaded.')

# torch.manual_seed(2)
print('MAE with extra GAN loss:')
run_one_image(img, model_mae_gan)
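# Note: assuming the standard MAE random masking, each forward pass samples a
# new mask; uncomment torch.manual_seed(2) above for a reproducible result.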