Skip to content

Commit 574e222

Browse files
author
wbw520
committed
add evaluation
1 parent 674b59a commit 574e222

10 files changed

+322
-15
lines changed

configs.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ def get_args_parser():
66

77
# train settings
88
parser.add_argument("--dataset", type=str, default="facade")
9-
parser.add_argument("--model_name", type=str, default="PSPNet")
10-
parser.add_argument("--pre_model", type=str, default="ViT-B_8.npz")
11-
parser.add_argument("--batch_size", type=int, default=4,
9+
parser.add_argument("--model_name", type=str, default="Segmenter")
10+
parser.add_argument("--pre_model", type=str, default="ViT-B_16.npz")
11+
parser.add_argument("--batch_size", type=int, default=1,
1212
help="Number of images sent to the network in one step.")
1313
parser.add_argument("--root", type=str, default="/home/wangbowen/DATA/",
1414
help="Path to the directory containing the image list.")
@@ -27,7 +27,7 @@ def get_args_parser():
2727
parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay.")
2828

2929
# VIT settings
30-
parser.add_argument("--encoder", type=str, default="vit_base_patch8", help="name for encoder")
30+
parser.add_argument("--encoder", type=str, default="vit_base_patch16", help="name for encoder")
3131
parser.add_argument("--decoder_embed_dim", type=int, default=512, help="dimension for decoder.")
3232
parser.add_argument("--decoder_depth", type=int, default=2, help="depth for decoder.")
3333
parser.add_argument("--decoder_num_head", type=int, default=8, help="head number for decoder.")

data/facade.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,17 @@ def polygon2mask(self, img_size, polygons, rectangles):
3131
return mask
3232

3333
# translate label_id to color img
34-
def id2trainId(self, label):
34+
def id2trainId(self, label, select=None):
3535
w, h = label.shape
3636
label_copy = np.zeros((w, h, 3), dtype=np.uint8)
3737
for index, color in colors.items():
38-
label_copy[label == index] = color
38+
if select is not None:
39+
if index == select:
40+
label_copy[label == index] = color
41+
else:
42+
continue
43+
else:
44+
label_copy[label == index] = color
3945
return label_copy.astype(np.uint8)
4046

4147

data/get_data_set.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_data(args, evaluation_setting=None):
2424
else:
2525
current_set = "val"
2626

27-
val_set = facade.Facade(args, 'test', joint_transform=joint_transformations_val,
27+
val_set = facade.Facade(args, current_set, joint_transform=joint_transformations_val,
2828
standard_transform=standard_transformations)
2929
ignore_index = facade.ignore_label
3030
args.num_classes = facade.num_classes

evaluation.py

-1
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,4 @@ def main():
3838
os.makedirs('demo/', exist_ok=True)
3939
parser = argparse.ArgumentParser('model training and evaluation script', parents=[get_args_parser()])
4040
args = parser.parse_args()
41-
img_path = "/home/wangbowen/DATA/Facade/translated_data/images/IMG_1287.png"
4241
main()

inference.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@
1515
import os
1616

1717

18-
def show_single(image, location=None, save=False):
18+
def show_single(image, location=None, save=False, name=None):
1919
# show single image
2020
image = np.array(image, dtype=np.uint8)
2121
plt.imshow(image)
2222

2323
plt.axis('off')
2424
plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
2525
plt.margins(0, 0)
26-
# if save:
27-
# plt.savefig("demo/" + img_name, bbox_inches='tight', pad_inches=0)
26+
if save:
27+
plt.savefig(name, bbox_inches='tight', pad_inches=0)
2828
plt.show()
2929

3030

@@ -48,16 +48,18 @@ def main():
4848

4949
standard_transformations = get_standard_transformations()
5050
img = Image.open(img_path).convert('RGB')
51+
5152
img = img.resize((args.setting_size[1], args.setting_size[0]), Image.BILINEAR)
5253
img = standard_transformations(img).to(device, dtype=torch.float32)
5354
pred, full_pred = inference_sliding(args, model, img.unsqueeze(0))
54-
color_img = PolygonTrans().id2trainId(torch.squeeze(pred, dim=0).cpu().detach().numpy())
55-
show_single(color_img, save=True)
55+
color_img = PolygonTrans().id2trainId(torch.squeeze(pred, dim=0).cpu().detach().numpy(), select=2)
56+
print(color_img.shape)
57+
show_single(color_img, save=True, name="color_mask2.png")
5658

5759

5860
if __name__ == '__main__':
5961
os.makedirs('demo/', exist_ok=True)
6062
parser = argparse.ArgumentParser('model training and evaluation script', parents=[get_args_parser()])
6163
args = parser.parse_args()
62-
img_path = "/home/wangbowen/DATA/Facade/translated_data/images/IMG_1287.png"
64+
img_path = "/home/wangbowen/DATA/Facade/translated_data/images/32052284_477d66a5ae_o.png"
6365
main()

line_detection/connect_components.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import cv2
2+
import numpy as np
3+
from inference import show_single
4+
5+
6+
img = cv2.imread("../demo/test.png")
7+
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
8+
ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
9+
10+
kernel2 = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
11+
bin_clo = cv2.erode(binary, kernel2, iterations=1)
12+
13+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(bin_clo, connectivity=8)
14+
15+
# print('num_labels = ', num_labels)
16+
# print('labels = ', labels)
17+
# # 不同的连通域赋予不同的颜色
18+
output = np.zeros((img.shape[0], img.shape[1], 3), np.uint8)
19+
20+
for i in range(1, num_labels):
21+
mask = labels == i
22+
# 连通域的信息:对应各个轮廓的x、y、width、height和面积
23+
print('stats = ', stats[i])
24+
if stats[i][4] < 30:
25+
continue
26+
# 连通域的中心点
27+
print('centroids = ', centroids[i])
28+
output[:, :, 0][mask] = np.random.randint(0, 255)
29+
output[:, :, 1][mask] = np.random.randint(0, 255)
30+
output[:, :, 2][mask] = np.random.randint(0, 255)
31+
break
32+
33+
show_single(output)

line_detection/hough.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import cv2
2+
import numpy as np
3+
from inference import show_single
4+
5+
src = cv2.imread("/home/wangbowen/DATA/Facade/translated_data/images/IMG_E1283.png")
6+
# src = cv2.imread("../demo/test.png")
7+
gray_img = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
8+
9+
10+
dst = cv2.equalizeHist(gray_img)
11+
# 高斯滤波降噪
12+
gaussian = cv2.GaussianBlur(dst, (9, 9), 0)
13+
# cv.imshow("gaussian", gaussian)
14+
15+
# 边缘检测
16+
edges = cv2.Canny(gaussian, 70, 150)
17+
show_single(edges)
18+
19+
# Hough 直线检测
20+
# 重点注意第四个参数 阈值,只有累加后的值高于阈值时才被认为是一条直线,也可以把它看成能检测到的直线的最短长度(以像素点为单位)
21+
# 在霍夫空间理解为:至少有多少条正弦曲线交于一点才被认为是直线
22+
lines = cv2.HoughLinesP(edges, 1, 1 * np.pi / 180, 10, minLineLength=10, maxLineGap=5)#统计概率霍夫线变换函数:图像矩阵,极坐标两个参数,一条直线所需最少的曲线交点,组成一条直线的最少点的数量,被认为在一条直线上的亮点的最大距离
23+
print("Line Num : ", len(lines))
24+
25+
# 画出检测的线段
26+
for line in lines:
27+
for x1, y1, x2, y2 in line:
28+
cv2.line(src, (x1, y1), (x2, y2), (255, 0, 0), 2)
29+
pass
30+
31+
show_single(src)

line_detection/line_revision.py

+215
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
import cv2
2+
import numpy as np
3+
import math
4+
from inference import show_single
5+
from shapely.geometry import LineString
6+
7+
8+
def lsd():
9+
src = cv2.imread("/home/wangbowen/DATA/Facade/translated_data/images/32052284_477d66a5ae_o.png")
10+
gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
11+
src = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
12+
gray = cv2.GaussianBlur(gray, (5, 5), 5)
13+
gray = cv2.GaussianBlur(gray, (3, 3), 5)
14+
15+
LSD = cv2.createLineSegmentDetector(0)
16+
dlines = LSD.detect(gray)
17+
18+
line_record = []
19+
20+
for dline in dlines[0]:
21+
x0 = int(round(dline[0][0]))
22+
y0 = int(round(dline[0][1]))
23+
x1 = int(round(dline[0][2]))
24+
y1 = int(round(dline[0][3]))
25+
line_record.append([x0, y0, x1, y1])
26+
27+
return line_record, src
28+
29+
30+
def calc_abc_from_line_2d(x0, y0, x1, y1):
31+
a = y0-y1
32+
b = x1-x0
33+
c = x0*y1-x1*y0
34+
return a, b, c
35+
36+
37+
def get_line_cross_point(line1, line2):
38+
a0, b0, c0 = calc_abc_from_line_2d(*line1)
39+
a1, b1, c1 = calc_abc_from_line_2d(*line2)
40+
D = a0*b1-a1*b0
41+
if D == 0:
42+
return None
43+
x = (b0*c1-b1*c0)/D
44+
y = (a1*c0-a0*c1)/D
45+
return x, y
46+
47+
48+
def combine(lines):
49+
def get_line(j):
50+
if index[j] == 1:
51+
current_line = [lines[j][2], lines[j][3], lines[j][4], lines[j][5]]
52+
else:
53+
if j == 0 or j == 2:
54+
current_line = [5, lines[j][0], 20, lines[j][0]]
55+
else:
56+
current_line = [lines[j][0], 20, lines[j][0], 40]
57+
return current_line
58+
59+
index = []
60+
for i in range(len(lines)):
61+
if len(lines[i]) == 1:
62+
index.append(0)
63+
else:
64+
index.append(1)
65+
66+
if np.array(index).sum() < 2:
67+
return None
68+
69+
cross_record = []
70+
start_line = None
71+
for s in range(len(index)):
72+
p_current_line = get_line(s)
73+
if s == 0:
74+
start_line = p_current_line
75+
76+
if s == 3:
77+
p_next_line = start_line
78+
else:
79+
p_next_line = get_line(s+1)
80+
cross_point = get_line_cross_point(p_current_line, p_next_line)
81+
cross_record.append(cross_point)
82+
return cross_record
83+
84+
85+
def line_search(enhance_stats, lines):
86+
x, y, w, h = enhance_stats
87+
top = LineString([(x, y), (x + w, y)])
88+
bottom = LineString([(x, y + h), (x + w, y + h)])
89+
left = LineString([(x, y), (x, y + h)])
90+
right = LineString([(x + w, y), (x + w, y + h)])
91+
92+
line_list = {"top": top, "left": left, "bottom": bottom, "right": right}
93+
distance_thresh = 10
94+
degree_thresh = 0.1
95+
max_selection = 1
96+
record = {"top": [], "left": [], "bottom": [], "right": []}
97+
98+
for key, value in line_list.items():
99+
# print(key)
100+
for (x0, y0, x1, y1) in lines:
101+
current_line = LineString([(x0, y0), (x1, y1)])
102+
current_degree = math.atan2(y0 - y1, x0 - x1)
103+
current_dis = value.distance(current_line)
104+
line_len = (x0 - x1)**2 + (y0 - y1)**2
105+
106+
if current_dis > distance_thresh:
107+
continue
108+
109+
# print([current_dis, current_degree, x0, y0, x1, y1])
110+
111+
if key == "top" or key == "bottom":
112+
if math.pi * 1/5 < abs(current_degree) < math.pi * 4/5:
113+
continue
114+
if line_len > w**2 * 1.5 or line_len < w**2 / 3:
115+
continue
116+
else:
117+
if abs(current_degree) < math.pi * 1/3 or abs(current_degree) > math.pi * 2/3:
118+
continue
119+
if line_len > h**2 * 1.5 or line_len < h**2 / 3:
120+
continue
121+
122+
status = True
123+
for i in range(len(record[key])):
124+
if abs(abs(abs(current_degree) - math.pi/2) - abs(abs(record[key][i][1]) - math.pi/2)) < degree_thresh:
125+
if record[key][i][0] > current_dis:
126+
record[key][i] = [current_dis, current_degree, x0, y0, x1, y1]
127+
status = False
128+
129+
if status:
130+
record[key].append([current_dis, current_degree, x0, y0, x1, y1])
131+
132+
final_line = []
133+
for key2, value2 in record.items():
134+
value2.sort(key=lambda s: s[0], reverse=False)
135+
num = min(len(value2), max_selection)
136+
if num == 0:
137+
if key2 == "top":
138+
final_line.append([y])
139+
elif key2 == "bottom":
140+
final_line.append([y + h])
141+
elif key2 == "left":
142+
final_line.append([x])
143+
else:
144+
final_line.append([x + w])
145+
continue
146+
147+
for j in range(num):
148+
final_line.append(value2[j])
149+
150+
return final_line
151+
152+
153+
def revision():
154+
img = cv2.imread("../demo/test.png")
155+
img = cv2.resize(img, (2048, 1152))
156+
img_orl = img
157+
158+
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
159+
ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
160+
kernel1 = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
161+
bin_clo = cv2.erode(binary, kernel1, iterations=1)
162+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(bin_clo, connectivity=8)
163+
164+
# print('num_labels = ', num_labels)
165+
# print('labels = ', labels)
166+
# # 不同的连通域赋予不同的颜色
167+
lines, scr = lsd()
168+
169+
for i in range(1, num_labels):
170+
# if i < 10:
171+
# continue
172+
mask = labels == i
173+
# 连通域的信息:对应各个轮廓的x、y、width、height和面积
174+
if stats[i][4] < 100:
175+
continue
176+
# # 连通域的信息:对应各个轮廓的x、y、width、height和面积
177+
# print('stats = ', stats[i])
178+
# # 连通域的中心点
179+
# print('centroids = ', centroids[i])
180+
181+
current_patch = np.zeros((img.shape[0], img.shape[1]), np.uint8)
182+
current_patch[mask] = 255
183+
# show_single(current_patch)
184+
185+
# kernel2 = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
186+
# current_patch = cv2.dilate(current_patch, kernel2, iterations=1)
187+
# show_single(current_patch)
188+
189+
x, y, w, h = cv2.boundingRect(current_patch)
190+
# print(x, y, w, h)
191+
# cv2.rectangle(current_patch, (x, y), (x + w, y + h), (225, 0, 255), 2)
192+
# show_single(current_patch)
193+
194+
detect_lines = line_search([x, y, w, h], lines)
195+
final_point = combine(detect_lines)
196+
197+
if final_point is None:
198+
continue
199+
200+
start_x, start_y = None, None
201+
for w in range(len(final_point)):
202+
x0, y0 = round(final_point[w][0]), round(final_point[w][1])
203+
if w == 0:
204+
start_x, start_y = x0, y0
205+
if w == 3:
206+
x1, y1 = start_x, start_y
207+
else:
208+
x1, y1 = round(final_point[w+1][0]), round(final_point[w+1][1])
209+
cv2.line(scr, (x0, y0), (x1, y1), 255, 2, cv2.LINE_AA)
210+
211+
show_single(scr, save=True, name="lines_detect.png")
212+
213+
214+
if __name__ == '__main__':
215+
revision()

0 commit comments

Comments
 (0)