Skip to content

Commit 245e339

Browse files
committed
🎉 support for end2end
1 parent 08bb095 commit 245e339

File tree

5 files changed

+110
-18
lines changed

5 files changed

+110
-18
lines changed

README.md

+15-4
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
## [简体中文](README_CN.md)
33

44
## Support
5-
YOLOv7、YOLOv6、 YOLOX、 YOLOV5、
5+
YOLOv7、YOLOv6、 YOLOX、 YOLOV5
6+
7+
The C++ code for YOLOv7/YOLOv6 also can be used for YOLOx or YOLOv5
68

79
## Update
10+
- 2022.8.11 nms plugin support ==> more simple
811
- 2022.7.8 support YOLOV7
912
- 2022.7.3 support TRT int8 post-training quantization
1013

@@ -48,15 +51,23 @@ python models/export.py --weights ../yolov7.pt --grid
4851

4952
```
5053
python export.py -o onnx-name -e trt-name -p fp32/16/int8
54+
55+
--end2end export the model include nms plugin
56+
5157
```
5258
### Test
5359

5460
```
5561
cd yolov7
5662
python trt.py
5763
```
64+
tips!
5865

59-
### C++
66+
if you use the end2end model please modift the code as such
67+
68+
`origin_img = pred.inference(img_path, conf=0.5, end2end=True)`
69+
70+
### C++ [Now don't support end2end model]
6071

6172
C++ [Demo](yolov7/cpp/README.md)
6273

@@ -84,7 +95,7 @@ python deploy/ONNX/export_onnx.py --weights yolov6s.pt --img 640 --batch 1
8495
### Convert to TensorRT Engine
8596

8697
```
87-
python export.py -o onnx-name -e trt-name -p fp32/16/int8
98+
python export.py -o onnx-name -e trt-name -p fp32/16/int8 --end2end
8899
```
89100
### Test
90101

@@ -93,7 +104,7 @@ cd yolov6
93104
python trt.py
94105
```
95106

96-
### C++
107+
### C++ [Now don't support end2end model]
97108

98109
C++ [Demo](yolov6/cpp/README.md)
99110

README_CN.md

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
YOLOv7、YOLOv6、 YOLOX、 YOLOV5、
55

66
## 更新
7+
- 2022.8.11 端到端导出支持, 更简洁的端到端导出方法
78
- 2022.7.8 支持YOLOV7
89
- 2022.7.3 支持 TRT int8 post-training quantization
910

export.py

+70-3
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(self, verbose=False, workspace=8):
112112
self.network = None
113113
self.parser = None
114114

115-
def create_network(self, onnx_path):
115+
def create_network(self, onnx_path, end2end, conf_thres, iou_thres, max_det):
116116
"""
117117
Parse the ONNX graph and create the corresponding TensorRT network definition.
118118
:param onnx_path: The path to the ONNX graph to load.
@@ -142,6 +142,61 @@ def create_network(self, onnx_path):
142142
assert self.batch_size > 0
143143
self.builder.max_batch_size = self.batch_size
144144

145+
if end2end:
146+
previous_output = self.network.get_output(0)
147+
self.network.unmark_output(previous_output)
148+
# output [1, 8400, 85]
149+
# slice boxes, obj_score, class_scores
150+
strides = trt.Dims([1,1,1])
151+
starts = trt.Dims([0,0,0])
152+
bs, num_boxes, temp = previous_output.shape
153+
shapes = trt.Dims([bs, num_boxes, 4])
154+
# [0, 0, 0] [1, 8400, 4] [1, 1, 1]
155+
boxes = self.network.add_slice(previous_output, starts, shapes, strides)
156+
num_classes = temp -5
157+
starts[2] = 4
158+
shapes[2] = 1
159+
# [0, 0, 4] [1, 8400, 1] [1, 1, 1]
160+
obj_score = self.network.add_slice(previous_output, starts, shapes, strides)
161+
starts[2] = 5
162+
shapes[2] = num_classes
163+
# [0, 0, 5] [1, 8400, 80] [1, 1, 1]
164+
scores = self.network.add_slice(previous_output, starts, shapes, strides)
165+
# scores = obj_score * class_scores => [bs, num_boxes, nc]
166+
updated_scores = self.network.add_elementwise(obj_score.get_output(0), scores.get_output(0), trt.ElementWiseOperation.PROD)
167+
168+
'''
169+
"plugin_version": "1",
170+
"background_class": -1, # no background class
171+
"max_output_boxes": detections_per_img,
172+
"score_threshold": score_thresh,
173+
"iou_threshold": nms_thresh,
174+
"score_activation": False,
175+
"box_coding": 1,
176+
'''
177+
registry = trt.get_plugin_registry()
178+
assert(registry)
179+
creator = registry.get_plugin_creator("EfficientNMS_TRT", "1")
180+
assert(creator)
181+
fc = []
182+
fc.append(trt.PluginField("background_class", np.array([-1], dtype=np.int32), trt.PluginFieldType.INT32))
183+
fc.append(trt.PluginField("max_output_boxes", np.array([max_det], dtype=np.int32), trt.PluginFieldType.INT32))
184+
fc.append(trt.PluginField("score_threshold", np.array([conf_thres], dtype=np.float32), trt.PluginFieldType.FLOAT32))
185+
fc.append(trt.PluginField("iou_threshold", np.array([iou_thres], dtype=np.float32), trt.PluginFieldType.FLOAT32))
186+
fc.append(trt.PluginField("box_coding", np.array([1], dtype=np.int32), trt.PluginFieldType.INT32))
187+
188+
fc = trt.PluginFieldCollection(fc)
189+
nms_layer = creator.create_plugin("nms_layer", fc)
190+
191+
layer = self.network.add_plugin_v2([boxes.get_output(0), updated_scores.get_output(0)], nms_layer)
192+
layer.get_output(0).name = "num"
193+
layer.get_output(1).name = "boxes"
194+
layer.get_output(2).name = "scores"
195+
layer.get_output(3).name = "classes"
196+
for i in range(4):
197+
self.network.mark_output(layer.get_output(i))
198+
199+
145200
def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=5000,
146201
calib_batch_size=8):
147202
"""
@@ -176,7 +231,8 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No
176231
# Also enable fp16, as some layers may be even more efficient in fp16 than int8
177232
self.config.set_flag(trt.BuilderFlag.FP16)
178233
self.config.set_flag(trt.BuilderFlag.INT8)
179-
self.config.int8_calibrator = EngineCalibrator(calib_cache)
234+
# self.config.int8_calibrator = EngineCalibrator(calib_cache)
235+
self.config.int8_calibrator = SwinCalibrator(calib_cache)
180236
if not os.path.exists(calib_cache):
181237
calib_shape = [calib_batch_size] + list(inputs[0].shape[1:])
182238
calib_dtype = trt.nptype(inputs[0].dtype)
@@ -190,7 +246,7 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No
190246

191247
def main(args):
192248
builder = EngineBuilder(args.verbose, args.workspace)
193-
builder.create_network(args.onnx)
249+
builder.create_network(args.onnx, args.end2end, args.conf_thres, args.iou_thres, args.max_det)
194250
builder.create_engine(args.engine, args.precision, args.calib_input, args.calib_cache, args.calib_num_images,
195251
args.calib_batch_size)
196252

@@ -210,7 +266,17 @@ def main(args):
210266
help="The maximum number of images to use for calibration, default: 5000")
211267
parser.add_argument("--calib_batch_size", default=8, type=int,
212268
help="The batch size for the calibration process, default: 8")
269+
parser.add_argument("--end2end", default=False, action="store_true",
270+
help="export the engine include nms plugin, default: False")
271+
parser.add_argument("--conf_thres", default=0.4, type=float,
272+
help="The conf threshold for the nms, default: 0.4")
273+
parser.add_argument("--iou_thres", default=0.5, type=float,
274+
help="The iou threshold for the nms, default: 0.5")
275+
parser.add_argument("--max_det", default=100, type=int,
276+
help="The total num for results, default: 100")
277+
213278
args = parser.parse_args()
279+
print(args)
214280
if not all([args.onnx, args.engine]):
215281
parser.print_help()
216282
log.error("These arguments are required: --onnx and --engine")
@@ -219,6 +285,7 @@ def main(args):
219285
parser.print_help()
220286
log.error("When building in int8 precision, --calib_input or an existing --calib_cache file is required")
221287
sys.exit(1)
288+
222289
main(args)
223290

224291

utils/utils.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(self, engine_path, imgsz=(640,640)):
2323

2424
logger = trt.Logger(trt.Logger.WARNING)
2525
runtime = trt.Runtime(logger)
26+
trt.init_libnvinfer_plugins(logger,'') # initialize TensorRT plugins
2627
with open(engine_path, "rb") as f:
2728
serialized_engine = f.read()
2829
engine = runtime.deserialize_cuda_engine(serialized_engine)
@@ -59,33 +60,45 @@ def infer(self, img):
5960
data = [out['host'] for out in self.outputs]
6061
return data
6162

62-
def detect_video(self, video_path):
63+
def detect_video(self, video_path, conf=0.5, end2end=False):
6364
cap = cv2.VideoCapture(video_path)
6465
while True:
6566
ret, frame = cap.read()
6667
if not ret:
6768
break
6869
blob, ratio = preproc(frame, self.imgsz, self.mean, self.std)
6970
data = self.infer(blob)
70-
predictions = np.reshape(data, (1, -1, int(5+self.n_classes)))[0]
71-
dets = self.postprocess(predictions,ratio)
71+
if end2end:
72+
num, final_boxes, final_scores, final_cls_inds = data
73+
final_boxes = np.reshape(final_boxes/ratio, (-1, 4))
74+
dets = np.concatenate([final_boxes[:num[0]], np.array(final_scores)[:num[0]].reshape(-1, 1), np.array(final_cls_inds)[:num[0]].reshape(-1, 1)], axis=-1)
75+
else:
76+
predictions = np.reshape(data, (1, -1, int(5+self.n_classes)))[0]
77+
dets = self.postprocess(predictions,ratio)
78+
7279
if dets is not None:
7380
final_boxes, final_scores, final_cls_inds = dets[:,
7481
:4], dets[:, 4], dets[:, 5]
7582
frame = vis(frame, final_boxes, final_scores, final_cls_inds,
76-
conf=0.5, class_names=self.class_names)
77-
cv2.imshow('frame', frame)
83+
conf=conf, class_names=self.class_names)
84+
cv2.imshow('frame', frame)
7885
if cv2.waitKey(25) & 0xFF == ord('q'):
7986
break
8087
cap.release()
8188
cv2.destroyAllWindows()
8289

83-
def inference(self, img_path, conf=0.5):
90+
def inference(self, img_path, conf=0.5, end2end=False):
8491
origin_img = cv2.imread(img_path)
8592
img, ratio = preproc(origin_img, self.imgsz, self.mean, self.std)
8693
data = self.infer(img)
87-
predictions = np.reshape(data, (1, -1, int(5+self.n_classes)))[0]
88-
dets = self.postprocess(predictions,ratio)
94+
if end2end:
95+
num, final_boxes, final_scores, final_cls_inds = data
96+
final_boxes = np.reshape(final_boxes/ratio, (-1, 4))
97+
dets = np.concatenate([final_boxes[:num[0]], np.array(final_scores)[:num[0]].reshape(-1, 1), np.array(final_cls_inds)[:num[0]].reshape(-1, 1)], axis=-1)
98+
else:
99+
predictions = np.reshape(data, (1, -1, int(5+self.n_classes)))[0]
100+
dets = self.postprocess(predictions,ratio)
101+
89102
if dets is not None:
90103
final_boxes, final_scores, final_cls_inds = dets[:,
91104
:4], dets[:, 4], dets[:, 5]

yolov6/trt.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def __init__(self, engine_path , imgsz=(640,640)):
2525

2626

2727
if __name__ == '__main__':
28-
pred = Predictor(engine_path='yolov6.trt')
28+
pred = Predictor(engine_path='yolov6-new.trt')
2929
img_path = '../src/3.jpg'
30-
origin_img = pred.inference(img_path)
30+
origin_img = pred.inference(img_path, conf=0.5, end2end=True)
3131
cv2.imwrite("%s_yolov6.jpg" % os.path.splitext(
3232
os.path.split(img_path)[-1])[0], origin_img)
33-
pred.detect_video('../src/video1.mp4') # set 0 use a webcam
33+
pred.detect_video('../src/video1.mp4', conf=0.5, end2end=False) # set 0 use a webcam
3434
pred.get_fps()

0 commit comments

Comments
 (0)