Skip to content

Commit 8cf2bb9

Browse files
author
wbw520
committed
modify some content
1 parent 2a2b262 commit 8cf2bb9

16 files changed

+255
-140
lines changed

configs.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@ def get_args_parser():
55
parser = argparse.ArgumentParser(description="PSP-Net Network", add_help=False)
66

77
# train settings
8-
parser.add_argument("--dataset", type=str, default="cityscapes")
8+
parser.add_argument("--dataset", type=str, default="facade")
99
parser.add_argument("--model_name", type=str, default="Segmenter")
10-
parser.add_argument("--pre_model", type=str, default="mae_pre_epoch99_crop640_patch8_ed768_depth12_head12.pt")
11-
parser.add_argument("--batch_size", type=int, default=1,
10+
parser.add_argument("--pre_model", type=str, default="ViT-B_16.npz")
11+
parser.add_argument("--batch_size", type=int, default=4,
1212
help="Number of images sent to the network in one step.")
13-
parser.add_argument("--root", type=str, default="/home/wangbowen/DATA/cityscapes",
13+
parser.add_argument("--root", type=str, default="/home/wangbowen/DATA/",
1414
help="Path to the directory containing the image list.")
15+
parser.add_argument("--setting_size", type=int, default=[1024, 2048],
16+
help="original size of data set image.")
1517
parser.add_argument("--crop_size", type=int, default=[640, 640],
1618
help="crop size for training and inference slice.")
1719
parser.add_argument("--stride_rate", type=float, default=0.5, help="stride ratio.")
18-
parser.add_argument("--num_epoch", type=int, default=200, help="Number of training steps.")
19-
parser.add_argument("--num_classes", type=int, default=19, help="Number of class for dataset.")
20+
parser.add_argument("--num_epoch", type=int, default=60, help="Number of training steps.")
2021
parser.add_argument('--accum_iter', default=1, type=int,
2122
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
2223

@@ -26,7 +27,7 @@ def get_args_parser():
2627
parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay.")
2728

2829
# VIT settings
29-
parser.add_argument("--patch_size", type=int, default=8, help="define the patch size.")
30+
parser.add_argument("--patch_size", type=int, default=16, help="define the patch size.")
3031
parser.add_argument("--encoder_embed_dim", type=int, default=768, help="dimension for encoder.")
3132
parser.add_argument("--decoder_embed_dim", type=int, default=512, help="dimension for decoder.")
3233
parser.add_argument("--encoder_depth", type=int, default=12, help="depth for encoder.")
@@ -38,10 +39,11 @@ def get_args_parser():
3839
parser.add_argument("--save_summary", type=str, default="save_model")
3940
parser.add_argument("--print_freq", type=str, default=5, help="print frequency.")
4041
parser.add_argument('--output_dir', default='save_model/', help='path where to save, empty for no saving')
42+
parser.add_argument("--use_ignore", type=bool, default=False)
4143

4244
# # distributed training parameters
4345
parser.add_argument('--num_workers', default=4, type=int)
44-
parser.add_argument("--device", type=str, default='cuda',
46+
parser.add_argument("--device", type=str, default='cuda:1',
4547
help="choose gpu device.")
4648
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
4749
parser.add_argument("--local_rank", type=int)

data/facade.py

+62-2
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,79 @@
22
from PIL import Image
33
import numpy as np
44
from utils.base_tools import get_name
5+
import json
6+
from sklearn.model_selection import train_test_split
7+
import cv2
58

69

710
# 255 marks pixels the training loss should skip entirely.
ignore_label = 255
num_classes = 10
# RGB color per class id, used when rendering a label map for visualisation.
colors = {0: [0, 0, 0], 1: [70, 70, 70], 2: [250, 170, 30], 3: [70, 130, 180], 4: [0, 60, 100], 5: [153, 153, 153],
          6: [107, 142, 35], 7: [255, 0, 0], 8: [0, 0, 142], 9: [220, 220, 0]}


class PolygonTrans():
    """Rasterize per-category polygon/rectangle annotations into a class-id
    mask, and render class-id masks as RGB images."""

    def __init__(self):
        # Class id assigned to each annotation category (0 is implicit background).
        self.binary = {"building": 1, "window": 2, "sky": 3, "roof": 4, "door": 5, "tree": 6, "people": 7, "car": 8, "sign": 9}
        # Categories are painted in this order, so later ones overwrite earlier
        # ones (e.g. a window overwrites the building it sits on).
        self.overlap_order = ["sky", "building", "roof", "door", "window", "tree", "people", "car", "sign"]

    def polygon2mask(self, img_size, polygons, rectangles):
        """Paint every annotation onto a zeroed uint8 canvas of shape img_size.

        polygons / rectangles map category name -> list of int32 point arrays
        (as produced by read_json).
        """
        canvas = np.zeros(img_size, dtype=np.uint8)
        for category in self.overlap_order:
            class_id = self.binary[category]
            cv2.fillPoly(canvas, polygons[category], color=class_id)
            # Rectangles are stored as two corner points; assumes the first is
            # top-left and the second bottom-right -- TODO confirm.
            for corners in rectangles[category]:
                x1, y1 = corners[0]
                x2, y2 = corners[1]
                canvas[y1:y2, x1:x2] = class_id
        return canvas

    def id2trainId(self, label):
        """Map a 2-D class-id array to an RGB uint8 image via `colors`."""
        height, width = label.shape
        rgb = np.zeros((height, width, 3), dtype=np.uint8)
        for class_id, rgb_value in colors.items():
            rgb[label == class_id] = rgb_value
        return rgb.astype(np.uint8)
40+
41+
42+
def read_json(file_name):
    """Parse a labelme-style annotation JSON into per-category point arrays.

    Returns a pair of dicts keyed by category name:
      - polygons: lists of int32 arrays for every non-rectangle shape
      - rectangles: lists of int32 arrays for "rectangle" shapes
    Shapes whose label is not one of the known categories are ignored.
    """
    categories = ("building", "window", "sky", "roof", "door", "tree", "people", "car", "sign")
    polygons = {cat: [] for cat in categories}
    rectangles = {cat: [] for cat in categories}

    with open(file_name, "r") as fp:
        payload = json.load(fp)

    for item in payload["shapes"]:
        label = item["label"]
        points = item["points"]
        shape_type = item["shape_type"]
        if label not in polygons:
            continue  # unknown category -- skip silently, as before

        target = rectangles if shape_type == "rectangle" else polygons
        target[label].append(np.array(points, dtype=np.int32))

    return polygons, rectangles
861

962

1063
def prepare_facade_data(args):
11-
items = get_name(args.root + "/translated_data/images")
64+
roots = args.root + "Facade/wang_translated_data/"
65+
items = get_name(roots + "images", mode_folder=False)
66+
record = []
67+
for item in items:
68+
record.append([roots + "images/" + item, roots + "binary_mask/" + item])
69+
70+
train, val = train_test_split(record, train_size=0.9, random_state=1)
71+
return {"train": train, "val": val}
1272

1373

1474
class Facade(torch.utils.data.Dataset):
1575
def __init__(self, args, mode, joint_transform=None, standard_transform=None):
1676
self.args = args
17-
self.imgs = ""
77+
self.imgs = prepare_facade_data(args)[mode]
1878
if len(self.imgs) == 0:
1979
raise RuntimeError('Found 0 images, please check the data set')
2080

data/facade_data_generation.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import cv2
2+
from utils.base_tools import get_name
3+
from data.facade import read_json, PolygonTrans
4+
import os
5+
from PIL import Image
6+
import shutil
7+
8+
9+
def main():
    """Convert every annotated photo under root/name into an image /
    class-id mask / color mask triple inside the *_translated_data folders.

    Relies on the module-level globals root, name, root_img,
    root_binary_mask and root_color_mask set in the __main__ guard.
    """
    # One entry per photo: strip the extension, de-duplicate while keeping
    # first-seen order (a photo and its .json share the same stem).
    stems = list(dict.fromkeys(
        entry.split(".")[0] for entry in get_name(root + name, mode_folder=False)))

    for img in stems:
        print(img)
        json_root = root + name + "/" + img + ".json"
        if not os.path.exists(json_root):
            print("file not exist: ", json_root)
            continue

        polygons, rectangles = read_json(json_root)
        # Photos whose stem contains "IMG_E" carry an upper-case extension.
        suffix = ".JPG" if "IMG_E" in img else ".jpg"

        image = cv2.imread(root + name + "/" + img + suffix, cv2.IMREAD_COLOR)
        # OpenCV loads BGR; convert so PIL saves correct colors.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w, c = image.shape

        translator = PolygonTrans()
        mask = translator.polygon2mask((h, w), polygons, rectangles)
        color_map = translator.id2trainId(mask)

        # NOTE(review): the class-id mask is saved as JPEG, which is lossy and
        # can corrupt label values near region edges -- PNG would be safer,
        # but the loader pairs files by identical name, so changing the
        # extension needs a coordinated update. TODO confirm.
        Image.fromarray(image).save(root_img + "/" + img + ".jpg")
        Image.fromarray(mask).save(root_binary_mask + "/" + img + ".jpg")
        Image.fromarray(color_map).save(root_color_mask + "/" + img + ".jpg")
45+
46+
47+
if __name__ == '__main__':
    root = "/home/wangbowen/DATA/Facade/"
    name = "zhao"
    # All output folders live under <root><name>_translated_data.
    # BUG FIX: the cleanup path used to be root + name + "translated_data"
    # (missing the "_"), so it never matched the output directory created
    # below and stale results were never removed.
    out_root = root + name + "_translated_data"
    shutil.rmtree(out_root, ignore_errors=True)
    root_img = out_root + "/images"
    root_color_mask = out_root + "/color_mask"
    root_binary_mask = out_root + "/binary_mask"
    os.makedirs(root_img, exist_ok=True)
    os.makedirs(root_color_mask, exist_ok=True)
    os.makedirs(root_binary_mask, exist_ok=True)
    main()

data/get_data_set.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from data import cityscapes
2+
from data import facade
3+
from data.loader_tools import get_joint_transformations, get_standard_transformations, get_joint_transformations_val
4+
5+
6+
def get_data(args):
    """Build the train/val datasets selected by args.dataset.

    Returns (train_set, val_set, ignore_index). As a side effect,
    args.num_classes is overwritten with the dataset's class count.

    Raises:
        ValueError: if args.dataset is neither "cityscapes" nor "facade".
    """
    joint_transformations = get_joint_transformations(args)
    joint_transformations_val = get_joint_transformations_val(args)
    standard_transformations = get_standard_transformations()

    if args.dataset == "cityscapes":
        train_set = cityscapes.CityScapes(args, 'fine', 'train', joint_transform=joint_transformations,
                                          standard_transform=standard_transformations)
        # Cityscapes validation runs on full-size images (no joint transform).
        val_set = cityscapes.CityScapes(args, 'fine', 'val', joint_transform=None,
                                        standard_transform=standard_transformations)
        ignore_index = cityscapes.ignore_label
        args.num_classes = cityscapes.num_classes
    elif args.dataset == "facade":
        train_set = facade.Facade(args, 'train', joint_transform=joint_transformations,
                                  standard_transform=standard_transformations)
        val_set = facade.Facade(args, 'val', joint_transform=joint_transformations_val,
                                standard_transform=standard_transformations)
        ignore_index = facade.ignore_label
        args.num_classes = facade.num_classes
    else:
        # BUG FIX: `raise "dataset name error !"` raised a string, which is a
        # TypeError in Python 3 (only BaseException subclasses can be raised)
        # and masked the real problem; raise a proper exception instead.
        raise ValueError("dataset name error !")

    return train_set, val_set, ignore_index

data/loader_tools.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,22 @@
44

55
def get_joint_transformations(args):
    """Joint (image, mask) augmentations used during training: resize to the
    configured setting size, then random crop and random horizontal flip."""
    pipeline = [
        joint_transforms.Resize(args),
        joint_transforms.RandomCrop(args.crop_size),
        joint_transforms.RandomHorizontallyFlip(),
    ]
    return joint_transforms.Compose(pipeline)
1314

1415

16+
def get_joint_transformations_val(args):
    """Joint (image, mask) transform for validation: only the deterministic
    resize, no random augmentation."""
    return joint_transforms.Compose([joint_transforms.Resize(args)])
21+
22+
1523
def get_standard_transformations():
1624
mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
1725
return standard_transforms.Compose([

data/transforms.py

+11
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,14 @@ def __call__(self, img, mask):
152152

153153
# return self.crop(*self.scale(img, mask))
154154
return img, mask
155+
156+
157+
class Resize(object):
    """Deterministically resize an (image, mask) pair to args.setting_size.

    The image is resampled bilinearly; the mask uses nearest-neighbour so
    class ids are never interpolated.
    """

    def __init__(self, args):
        # args.setting_size is (height, width).
        self.h = args.setting_size[0]
        self.w = args.setting_size[1]

    def __call__(self, img, mask):
        # PIL's resize takes (width, height).
        target = (self.w, self.h)
        return img.resize(target, Image.BILINEAR), mask.resize(target, Image.NEAREST)

inference.py

+9-15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import argparse
22
import torch.backends.cudnn as cudnn
3-
from data import cityscapes
43
import torch
54
from PIL import Image
65
from data.loader_tools import get_standard_transformations
@@ -9,6 +8,7 @@
98
from model.get_model import model_generation
109
from utils.engine import inference_sliding
1110
from data.cityscapes import ColorTransition
11+
from data.facade import PolygonTrans
1212
import matplotlib.pyplot as plt
1313
import numpy as np
1414
import os
@@ -30,8 +30,8 @@ def show_single(image, location=None, save=False):
3030
plt.axis('off')
3131
plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
3232
plt.margins(0, 0)
33-
if save:
34-
plt.savefig("demo/" + img_name, bbox_inches='tight', pad_inches=0)
33+
# if save:
34+
# plt.savefig("demo/" + img_name, bbox_inches='tight', pad_inches=0)
3535
plt.show()
3636

3737

@@ -42,29 +42,23 @@ def main():
4242
cudnn.benchmark = True
4343
model = model_generation(args)
4444
model.to(device)
45-
checkpoint = torch.load(args.output_dir + "/48_epoch_PSPNet.pt", map_location="cuda:0")
45+
checkpoint = torch.load(args.output_dir + args.dataset + "_" + args.model_name + ".pt", map_location="cuda:1")
4646
model.load_state_dict(checkpoint, strict=True)
4747
model.eval()
4848

4949
standard_transformations = get_standard_transformations()
5050
img = Image.open(img_path).convert('RGB')
51+
img = img.resize((args.setting_size[1], args.setting_size[0]), Image.BILINEAR)
5152
img = standard_transformations(img).to(device, dtype=torch.float32)
5253
pred, full_pred = inference_sliding(args, model, img.unsqueeze(0))
53-
color_img = ColorTransition().recover(torch.squeeze(pred, dim=0))
54+
color_img = PolygonTrans().id2trainId(torch.squeeze(pred, dim=0).cpu().detach().numpy())
5455
show_single(color_img, save=True)
5556

5657

5758
if __name__ == '__main__':
    # Predictions are rendered into demo/; make sure it exists up front.
    os.makedirs('demo/', exist_ok=True)
    parser = argparse.ArgumentParser('model training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    # NOTE(review): hard-coded to 10 classes (the facade set); presumably this
    # should follow facade.num_classes -- confirm before reusing with another
    # dataset.
    args.num_classes = 10
    # Single demo image consumed by main() via the module-level img_path.
    img_path = "/home/wangbowen/DATA/Facade/zhao_translated_data/images/IMG_1282.jpg"
    main()

0 commit comments

Comments
 (0)