Commit 674b59a

wbw520 committed
add evaluation
1 parent e295b09 · commit 674b59a

10 files changed, +139 -36 lines changed

configs.py (+6 -10)

@@ -7,8 +7,8 @@ def get_args_parser():
     # train settings
     parser.add_argument("--dataset", type=str, default="facade")
     parser.add_argument("--model_name", type=str, default="PSPNet")
-    parser.add_argument("--pre_model", type=str, default="ViT-B_16.npz")
-    parser.add_argument("--batch_size", type=int, default=2,
+    parser.add_argument("--pre_model", type=str, default="ViT-B_8.npz")
+    parser.add_argument("--batch_size", type=int, default=4,
                         help="Number of images sent to the network in one step.")
     parser.add_argument("--root", type=str, default="/home/wangbowen/DATA/",
                         help="Path to the directory containing the image list.")
@@ -27,12 +27,9 @@ def get_args_parser():
     parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay.")

     # VIT settings
-    parser.add_argument("--patch_size", type=int, default=16, help="define the patch size.")
-    parser.add_argument("--encoder_embed_dim", type=int, default=768, help="dimension for encoder.")
+    parser.add_argument("--encoder", type=str, default="vit_base_patch8", help="name for encoder")
     parser.add_argument("--decoder_embed_dim", type=int, default=512, help="dimension for decoder.")
-    parser.add_argument("--encoder_depth", type=int, default=12, help="depth for encoder.")
     parser.add_argument("--decoder_depth", type=int, default=2, help="depth for decoder.")
-    parser.add_argument("--encoder_num_head", type=int, default=12, help="head number for encoder.")
     parser.add_argument("--decoder_num_head", type=int, default=8, help="head number for decoder.")

     # other settings
@@ -42,11 +39,10 @@ def get_args_parser():
     parser.add_argument("--use_ignore", type=bool, default=False)

     # # distributed training parameters
-    parser.add_argument('--num_workers', default=0, type=int)
-    parser.add_argument("--device", type=str, default='cuda:1',
-                        help="choose gpu device.")
+    parser.add_argument('--num_workers', default=4, type=int)
+    parser.add_argument("--device", type=str, default='cuda', help="choose gpu device.")
     parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
-    parser.add_argument("--local_rank", type=int)
+    parser.add_argument("--local_rank", default=-1, type=int)
     parser.add_argument('--dist_on_itp', action='store_true')
     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

data/facade.py (+1 -2)

@@ -69,7 +69,7 @@ def prepare_facade_data(args):
             record.append([roots + "images/" + item, roots + "binary_mask/" + item])

     train, other = train_test_split(record, train_size=0.8, random_state=1)
-    val, test = train_test_split(record, train_size=0.5, random_state=1)
+    val, test = train_test_split(other, train_size=0.5, random_state=1)
     return {"train": train, "val": val, "test": test}


@@ -87,7 +87,6 @@ def __init__(self, args, mode, joint_transform=None, standard_transform=None):

     def __getitem__(self, index):
         img_path, mask_path = self.imgs[index]
-        print(img_path)
         img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)

         mask = np.array(mask)

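The one-line change in prepare_facade_data above makes the validation and test splits disjoint from the training split: val and test are now carved out of the 20% held back by the first split instead of being re-sampled from the full record list. A minimal sketch of the resulting proportions, with a dummy list standing in for the [image path, mask path] pairs:

    from sklearn.model_selection import train_test_split

    record = list(range(100))  # stand-in for the [image path, mask path] pairs
    train, other = train_test_split(record, train_size=0.8, random_state=1)
    val, test = train_test_split(other, train_size=0.5, random_state=1)
    print(len(train), len(val), len(test))  # 80 10 10 -> a disjoint 80/10/10 split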
data/facade_data_generation.py (-2)

@@ -34,8 +34,6 @@ def main():

         PT = PolygonTrans()
         mask = PT.polygon2mask((h, w), polygons, rectangles)
-        if np.sum(mask == 10):
-            print("--------------------")

         color_map = PT.id2trainId(mask)

data/get_data_set.py (+8 -2)

@@ -3,7 +3,7 @@
 from data.loader_tools import get_joint_transformations, get_standard_transformations, get_joint_transformations_val


-def get_data(args):
+def get_data(args, evaluation_setting=None):
     joint_transformations = get_joint_transformations(args)
     joint_transformations_val = get_joint_transformations_val(args)
     standard_transformations = get_standard_transformations()
@@ -18,7 +18,13 @@ def get_data(args):
     elif args.dataset == "facade":
         train_set = facade.Facade(args, 'train', joint_transform=joint_transformations,
                                   standard_transform=standard_transformations)
-        val_set = facade.Facade(args, 'val', joint_transform=joint_transformations_val,
+
+        if evaluation_setting is not None:
+            current_set = "test"
+        else:
+            current_set = "val"
+
+        val_set = facade.Facade(args, 'test', joint_transform=joint_transformations_val,
                                   standard_transform=standard_transformations)
         ignore_index = facade.ignore_label
         args.num_classes = facade.num_classes

evaluation.py (+42)

@@ -0,0 +1,42 @@
+import argparse
+import torch.backends.cudnn as cudnn
+import torch
+from data.get_data_set import get_data
+import utils2.misc as misc
+from configs import get_args_parser
+from utils.engine import evaluation_none_training
+from torch.utils.data import DataLoader
+from model.get_model import model_generation
+import os
+
+
+def main():
+    # distribution
+    misc.init_distributed_mode(args)
+    device = torch.device(args.device)
+    cudnn.benchmark = True
+    train_set, test_set, ignore_index = get_data(args)
+    model = model_generation(args)
+    model.to(device)
+
+    if args.model_name == "Segmenter":
+        save_name = args.model_name + "_" + args.encoder
+    else:
+        save_name = args.model_name
+
+    checkpoint = torch.load(args.output_dir + args.dataset + "_" + save_name + ".pt", map_location="cuda:0")
+    model.load_state_dict(checkpoint, strict=True)
+    model.eval()
+    print("load trained model finished.")
+
+    sampler_val = torch.utils.data.SequentialSampler(test_set)
+    val_loader = DataLoader(test_set, batch_size=args.batch_size, sampler=sampler_val, num_workers=args.num_workers, shuffle=False)
+    evaluation_none_training(args, model, val_loader, device)
+
+
+if __name__ == '__main__':
+    os.makedirs('demo/', exist_ok=True)
+    parser = argparse.ArgumentParser('model training and evaluation script', parents=[get_args_parser()])
+    args = parser.parse_args()
+    img_path = "/home/wangbowen/DATA/Facade/translated_data/images/IMG_1287.png"
+    main()

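A sketch of how this entry point might be launched, assuming a checkpoint saved during training already exists under args.output_dir (for the flag values below the script would look for facade_Segmenter_vit_base_patch8.pt); only flags that appear in configs.py in this commit are used:

    python evaluation.py --dataset facade --model_name Segmenter --encoder vit_base_patch8 --batch_size 4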
inference.py (+11 -12)

@@ -1,27 +1,20 @@
 import argparse
 import torch.backends.cudnn as cudnn
 import torch
+from data.get_data_set import get_data
 from PIL import Image
 from data.loader_tools import get_standard_transformations
 import utils2.misc as misc
+from utils.base_tools import get_name
 from configs import get_args_parser
 from model.get_model import model_generation
 from utils.engine import inference_sliding
-from data.cityscapes import ColorTransition
 from data.facade import PolygonTrans
 import matplotlib.pyplot as plt
 import numpy as np
 import os


-def get_name(root, mode_folder=True):
-    for root, dirs, file in os.walk(root):
-        if mode_folder:
-            return dirs
-        else:
-            return file
-
-
 def show_single(image, location=None, save=False):
     # show single image
     image = np.array(image, dtype=np.uint8)
@@ -40,9 +33,16 @@ def main():
     misc.init_distributed_mode(args)
     device = torch.device(args.device)
     cudnn.benchmark = True
+    train_set, val_set, ignore_index = get_data(args)
     model = model_generation(args)
     model.to(device)
-    checkpoint = torch.load(args.output_dir + args.dataset + "_" + args.model_name + ".pt", map_location="cuda:1")
+
+    if args.model_name == "Segmenter":
+        save_name = args.model_name + "_" + args.encoder
+    else:
+        save_name = args.model_name
+
+    checkpoint = torch.load(args.output_dir + args.dataset + "_" + save_name + ".pt", map_location="cuda:0")
     model.load_state_dict(checkpoint, strict=True)
     model.eval()

@@ -59,6 +59,5 @@ def main():
     os.makedirs('demo/', exist_ok=True)
     parser = argparse.ArgumentParser('model training and evaluation script', parents=[get_args_parser()])
     args = parser.parse_args()
-    args.num_classes = 10
-    img_path = "/home/wangbowen/DATA/Facade/zhao_translated_data/images/IMG_1282.jpg"
+    img_path = "/home/wangbowen/DATA/Facade/translated_data/images/IMG_1287.png"
     main()

model/segmenter.py (+22 -4)

@@ -1,9 +1,10 @@
-from model.vit_model import vit_encoder
+import model.vit_model as vit
 from model.segmenter_decoder import sg_vit_mask_decoder
 import torch.nn as nn
 import torch
 import os
 import torch.nn.functional as F
+import re
 from timm.models.helpers import load_checkpoint


@@ -43,9 +44,25 @@ def get_attention_map_dec(self, im, layer_id):
         return self.decoder.get_attention_map(x, layer_id)


-def create_segmenter(args):
-    encoder = vit_encoder(img_size=args.crop_size[0], patch_size=args.patch_size, embed_dim=args.encoder_embed_dim, depth=args.encoder_depth, num_heads=args.encoder_num_head)
+def set_decoder_parameter(name):
+    if "small" in name:
+        encoder_dim = 384
+    elif "base" in name:
+        encoder_dim = 768
+    elif "large" in name:
+        encoder_dim = 1024
+    elif "huge" in name:
+        encoder_dim = 1280
+    else:
+        raise "type of encoder is not defined."
+
+    patch_size = re.findall("\d+", name)

+    return encoder_dim, int(patch_size[0])
+
+
+def create_segmenter(args):
+    encoder = vit.__dict__[args.encoder](img_size=args.crop_size[0])
     if "mae" not in args.pre_model:
         print("load pre-model trained by ImageNet")
         load_checkpoint(encoder, args.output_dir + args.pre_model)
@@ -64,7 +81,8 @@ def create_segmenter(args):

     print("load pre-trained weight from: ", args.pre_model)

-    decoder = sg_vit_mask_decoder(patch_size=args.patch_size, encoder_embed_dim=args.encoder_embed_dim,
+    encoder_embed_dim, patch_size = set_decoder_parameter(args.encoder)
+    decoder = sg_vit_mask_decoder(patch_size=patch_size, encoder_embed_dim=encoder_embed_dim,
                                   decoder_embed_dim=args.decoder_embed_dim, decoder_depth=args.decoder_depth, decoder_num_heads=args.decoder_num_head, n_cls=args.num_classes)
     model = Segmenter(encoder, decoder)

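The new set_decoder_parameter helper recovers the decoder's expected embedding width and the patch size from the encoder name alone, which is why the --patch_size, --encoder_embed_dim, --encoder_depth and --encoder_num_head flags could be dropped from configs.py. A small sketch of the mapping, assuming the repository root is on PYTHONPATH (return values follow the branches above):

    from model.segmenter import set_decoder_parameter

    print(set_decoder_parameter("vit_base_patch8"))   # (768, 8)
    print(set_decoder_parameter("vit_base_patch16"))  # (768, 16)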
model/vit_model.py (+24 -2)

@@ -78,6 +78,28 @@ def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
     return posemb


-def vit_encoder(**kwargs):
-    model = VisionTransformer(mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes=0, **kwargs)
+def vit_base_patch8(**kwargs):
+    model = VisionTransformer(patch_size=8, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+                              norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes=0,
+                              **kwargs)
     return model
+
+
+def vit_base_patch16(**kwargs):
+    model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+                              norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes=0,
+                              **kwargs)
+    return model
+
+
+def vit_base_patch32(**kwargs):
+    model = VisionTransformer(patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+                              norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes=0,
+                              **kwargs)
+    return model
+
+
+# set recommended archs
+vit_base_patch8 = vit_base_patch8
+vit_base_patch16 = vit_base_patch16
+vit_base_patch32 = vit_base_patch32

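With the encoders exposed as named module-level constructors, create_segmenter in model/segmenter.py can resolve the --encoder string through the module dictionary instead of assembling a ViT from individual hyperparameter flags. A minimal sketch of that lookup, with an assumed 512-pixel crop size standing in for args.crop_size[0]:

    import model.vit_model as vit

    encoder_name = "vit_base_patch8"                    # value of args.encoder in configs.py
    encoder = vit.__dict__[encoder_name](img_size=512)  # assumed crop size; the real call uses args.crop_size[0]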
train.py (+1 -1)

@@ -28,7 +28,7 @@ def main():
     model_without_ddp = model

     if args.distributed:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
         model_without_ddp = model.module
         sampler_train = DistributedSampler(train_set)
         sampler_val = DistributedSampler(val_set, shuffle=False)

utils/engine.py (+24 -1)

@@ -86,7 +86,11 @@ def evaluation(args, best_record, epoch, model, model_without_ddp, val_loader, c
         best_record['acc_cls'] = acc_cls
         best_record['mean_iou'] = mean_iou
         if args.output_dir:
-            torch.save(model_without_ddp.state_dict(), args.output_dir + args.dataset + "_" + args.model_name + ".pt")
+            if args.model_name == "Segmenter":
+                save_name = args.model_name + "_" + args.encoder
+            else:
+                save_name = args.model_name
+            torch.save(model_without_ddp.state_dict(), args.output_dir + args.dataset + "_" + save_name + ".pt")

     print('-----------------------------------------------------------------------------------------------------------')
     print('[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iou %.5f]' % (
@@ -99,6 +103,25 @@ def evaluation(args, best_record, epoch, model, model_without_ddp, val_loader, c
     print('-----------------------------------------------------------------------------------------------------------')


+@torch.no_grad()
+def evaluation_none_training(args, model, val_loader, device):
+    model.eval()
+    iou = IouCal(args)
+    for i_batch, data in enumerate(val_loader):
+        if i_batch % 5 == 0:
+            print(str(i_batch) + "/" + str(len(val_loader)))
+        inputs = data["images"].to(device, dtype=torch.float32)
+        mask = data["masks"].to(device, dtype=torch.int64)
+
+        pred, full_pred = inference_sliding(args, model, inputs)
+        iou.evaluate(pred, mask)
+
+    acc, acc_cls, mean_iou = iou.iou_demo()
+    print("acc:", acc)
+    print("acc_cls", acc_cls)
+    print("mean_iou", mean_iou)
+
+
 @torch.no_grad()
 def inference_sliding(args, model, image):
     image_size = image.size()
