Skip to content

Commit 730523a

Browse files
committed
soft nms support
1 parent 1f6e694 commit 730523a

File tree

3 files changed

+110
-12
lines changed

3 files changed

+110
-12
lines changed

yolo3/detect/img_detect.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
from yolo3.dataset.dataset import pad_to_square, resize
1717
from yolo3.utils.helper import load_classes
18-
from yolo3.utils.model_build import non_max_suppression, rescale_boxes, xywh2p1p2, resize_boxes
18+
from yolo3.utils.model_build import non_max_suppression, rescale_boxes, xywh2p1p2, resize_boxes, \
19+
soft_non_max_suppression
1920

2021

2122
def scale(image, shape, max_size):
@@ -83,7 +84,7 @@ def detect(self, img):
8384
prev_time = time.time()
8485
with torch.no_grad():
8586
detections = self.model(image)
86-
detections = non_max_suppression(detections, self.thres, self.nms_thres)
87+
detections = soft_non_max_suppression(detections, self.thres, self.nms_thres)
8788
detections = detections[0]
8889
if detections is not None:
8990
# detections = rescale_boxes(detections, self.model.img_size, (h, w))
@@ -140,7 +141,9 @@ def detect(self, img):
140141
# (1, n, 5 + num_class)
141142
rescaled_detections = torch.cat(rescaled_detections, 0).unsqueeze(0)
142143

143-
detections = non_max_suppression(rescaled_detections, self.thres, self.nms_thres, is_p1p2=True)
144+
detections = soft_non_max_suppression(rescaled_detections, self.thres, self.nms_thres,
145+
merge=True,
146+
is_p1p2=True)
144147
detections = detections[0]
145148

146149
current_time = time.time()

yolo3/utils/label_draw.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import numpy as np
55
import time
66

7+
78
def _get_statistic_info(detections, unique_labels, classes):
89
"""获得统计信息"""
910
statistic_info = {}
1011
for label in unique_labels:
1112
statistic_info[classes[int(label)]] = (
12-
detections[:, -1] == label).sum().item()
13+
detections[:, -1] == label).sum().item()
1314
return statistic_info
1415

1516

@@ -52,12 +53,12 @@ def draw_rects_and_labels(img, dets, colors, labels, thickness, font_size, font=
5253
text_size, _ = cv2.getTextSize(
5354
labels[i], cv2.FONT_HERSHEY_COMPLEX_SMALL, font_size, 1)
5455
font_w, font_h = text_size
55-
cv2.rectangle(img, (c1[0], max(0, int( c1[1] - 3 - 18 * font_size))),
56+
cv2.rectangle(img, (c1[0], max(0, int(c1[1] - 3 - 18 * font_size))),
5657
(c1[0] + font_w, max(c1[1], int(3 + 18 * font_size))), colors[cls], -1)
5758
cv2.putText(img,
5859
labels[i],
5960
(c1[0], max(c1[1] - 3, font_h)), cv2.FONT_HERSHEY_COMPLEX_SMALL, font_size,
60-
(0, 0, 0), 2)
61+
(0, 0, 0), 1)
6162
return img
6263

6364

@@ -79,14 +80,18 @@ def draw_single_img(img, detections, img_size,
7980
unique_labels = detections[:, -1].unique()
8081
statistic_info = _get_statistic_info(
8182
detections, unique_labels, classes)
82-
83+
8384
if only_rect:
8485
draw_rects(img, detections, colors, thickness)
8586
else:
8687
labels = []
8788
for detection in detections:
88-
labels.append(classes[int(detection[-1])] +
89-
' (' + str(round(detection[-3] * detection[-2] * 100, 2)) + '%)')
89+
if len(detection) == 7:
90+
labels.append(classes[int(detection[-1])] +
91+
' (' + str(round(detection[-3] * detection[-2] * 100, 2)) + '%)')
92+
else:
93+
labels.append(classes[int(detection[-1])] +
94+
' (' + str(round(detection[-2] * 100, 2)) + '%)')
9095
draw_rects_and_labels(img, detections, colors,
9196
labels, thickness, font_size, font)
9297

@@ -168,10 +173,10 @@ def draw_labels_by_trackers(self, img, detections, only_rect):
168173
for detection in detections:
169174
if self.id2label is not None and str(int(detection[4])) in self.id2label:
170175
label = str(int(detection[4])) + ":" + \
171-
self.id2label[str(int(detection[4]))]
176+
self.id2label[str(int(detection[4]))]
172177
else:
173178
label = str(int(detection[4])) + ":" + \
174-
self.classes[int(detection[-1])]
179+
self.classes[int(detection[-1])]
175180
labels.append(label)
176181

177182
# 绘制所有标签

yolo3/utils/model_build.py

+91-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import time
2+
13
import torch
24
import numpy as np
5+
import torchvision
36
import tqdm
47
from torchvision.ops.boxes import batched_nms
58

@@ -46,6 +49,94 @@ def rescale_boxes(boxes, current_dim, original_shape):
4649
return boxes
4750

4851

52+
def soft_non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False,
                             is_p1p2=False):
    """Filter raw predictions by confidence, run NMS and optionally merge overlapping boxes.

    Despite the name, suppression itself is the hard ``torchvision.ops.boxes.nms``;
    the extra behaviour is the optional ``merge`` step, which replaces every kept
    box by the score-weighted mean of all candidate boxes overlapping it.

    Args:
        prediction: tensor indexed as (image, box, 5 + num_classes); each row is
            4 box coordinates, objectness, then one score per class.
        conf_thres: threshold applied to objectness first and to
            objectness * class score afterwards.
        iou_thres: IoU threshold used by NMS and by the merge step.
        merge: when True (and an image has between 2 and 2999 candidates),
            merge overlapping boxes and drop kept boxes that no other
            candidate overlaps.
        classes: optional collection (list, tuple, array or tensor) of class
            indices to keep; ``None`` or an empty collection keeps every class.
        agnostic: when True, boxes of different classes suppress each other.
        is_p1p2: True when the boxes are already (x1, y1, x2, y2); otherwise
            they are converted from (center x, center y, width, height).

    Returns:
        A list with one entry per image: an (n, 6) tensor of
        (x1, y1, x2, y2, conf, cls) rows, or ``None`` when no detection
        survives or the time limit was hit before the image was processed.
    """
    if prediction.dtype is torch.float16:
        prediction = prediction.float()  # to FP32

    # Read the class count from the shape so an empty batch does not raise.
    num_classes = prediction.shape[2] - 5
    candidates = prediction[..., 4] > conf_thres

    # Settings
    max_wh = 4096  # (pixels) per-class coordinate offset used to separate classes in NMS
    max_det = 300  # maximum number of detections per image
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label = num_classes > 1  # multiple labels per box

    # Normalise the class filter once; truth-testing a tensor/array directly is
    # ambiguous for several elements and wrong for a single class index 0.
    class_filter = None
    if classes is not None:
        class_filter = torch.as_tensor(classes).reshape(-1)
        if class_filter.numel() == 0:
            class_filter = None

    start = time.time()
    output = [None] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[candidates[xi]]  # boolean indexing copies, so `prediction` is not modified below

        # If none remain process next image
        if not x.shape[0]:
            continue

        # conf = obj_conf * cls_conf
        x[:, 5:] *= x[:, 4:5]

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = x[:, :4] if is_p1p2 else xywh2p1p2(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # One output row per (box, class) pair above the threshold.
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if class_filter is not None:
            x = x[(x[:, 5:6] == class_filter.to(x.device)).any(1)]

        # If none remain process next image
        n = x.shape[0]  # number of boxes
        if not n:
            continue

        # Batched NMS: shifting each class by a large offset keeps boxes of
        # different classes from overlapping unless `agnostic` is set.
        c = x[:, 5:6] * (0 if agnostic else max_wh)
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # NOTE(review): bbox_iou is defined outside this block; this step
            # needs it to return the pairwise (len(i), n) IoU matrix — confirm.
            try:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
                iou = bbox_iou(boxes[i], boxes) > iou_thres  # iou matrix
                weights = iou * scores[None]  # box weights
                x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
                if redundant:
                    i = i[iou.sum(1) > 1]  # require redundancy
            except Exception as err:
                # Merging is best-effort (possible CUDA error, see
                # https://github.com/ultralytics/yolov3/issues/1139): fall back
                # to the plain NMS result, but do not swallow
                # KeyboardInterrupt/SystemExit as a bare except would.
                print(err, x.shape, i.shape)

        output[xi] = x[i]
        if (time.time() - start) > time_limit:
            break  # time limit exceeded; remaining images stay None

    return output
138+
139+
49140
def non_max_suppression(prediction, thres=0.5, nms_thres=0.4, is_p1p2=False):
50141
"""
51142
Removes detections with lower object confidence score than 'conf_thres' and performs
@@ -73,7 +164,6 @@ def non_max_suppression(prediction, thres=0.5, nms_thres=0.4, is_p1p2=False):
73164
# If none anchor are remaining => process next image
74165
if not image_pred.size(0):
75166
continue
76-
77167

78168
detections = torch.cat((image_pred[:, :5],
79169
class_confs.type(prediction.dtype),

0 commit comments

Comments
 (0)