Commit 24cd571

onnx inference
1 parent 9abc442 commit 24cd571

File tree

1 file changed: +103 -6 lines changed

onnx_infer.py (+103 -6)
@@ -4,6 +4,12 @@
 import sys
 import torch
 import numpy as np
+import cv2
+import os
+import torchvision  # needed for torchvision.ops.nms in non_max_suppression
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+from dataloader.data_transforms import *
+from util.tools import *

 def parse_args():
     parser = argparse.ArgumentParser(description="onnx_inference")
@@ -16,25 +22,116 @@ def parse_args():
     args = parser.parse_args()
     return args

+def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None):
+    """Performs Non-Maximum Suppression (NMS) on inference results.
+    Returns:
+        detections with shape nx6 (x1, y1, x2, y2, conf, cls)
+    """
+
+    nc = prediction.shape[2] - 5  # number of classes
+
+    # Settings
+    # (pixels) maximum box width and height
+    max_wh = 4096
+    max_det = 300  # maximum number of detections per image
+    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
+    time_limit = 1.0  # seconds to quit after (unused in this version)
+    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
+
+    output = [np.zeros(6)] * prediction.shape[0]
+
+    for xi, x in enumerate(prediction):  # image index, image inference
+        # Apply constraints
+        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
+        x = x[x[..., 4] > conf_thres]  # filter by objectness confidence
+
+        # If none remain, process next image
+        if not x.shape[0]:
+            continue
+
+        # Compute conf
+        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
+
+        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+        box = cxcy2minmax(x[:, :4])
+
+        # Detections matrix nx6 (xyxy, conf, cls)
+        if multi_label:
+            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
+            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
+        else:  # best class only
+            conf, j = x[:, 5:].max(1, keepdim=True)
+            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
+
+        # Filter by class
+        if classes is not None:
+            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+        # Check shape
+        n = x.shape[0]  # number of boxes
+        if not n:  # no boxes
+            continue
+        elif n > max_nms:  # excess boxes
+            # keep the max_nms highest-confidence boxes
+            x = x[x[:, 4].argsort(descending=True)[:max_nms]]
+
+        # Batched NMS: offset boxes by class index so boxes of different classes never overlap
+        c = x[:, 5:6] * max_wh  # class offsets
+        # boxes (offset by class), scores
+        boxes, scores = x[:, :4] + c, x[:, 4]
+        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+        if i.shape[0] > max_det:  # limit detections
+            i = i[:max_det]
+
+        output[xi] = x[i].detach().cpu()
+
+    return output
+
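(For orientation — a minimal usage sketch, not part of the commit. It assumes the YOLO-style output layout (1, N, 5 + nc) implied by the code above, and that cxcy2minmax from util.tools accepts torch tensors. Because the helper leans on torch-only calls such as Tensor.nonzero(as_tuple=False) and torchvision.ops.nms, the raw numpy output of the ONNX Runtime session built in main() below has to be wrapped in a tensor first:

    raw = ort_session.run(None, ort_inputs)[0]   # numpy, assumed shape (1, N, 5 + nc)
    pred = torch.from_numpy(raw)                 # the helper relies on torch ops
    dets = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
    for det in dets:
        if not torch.is_tensor(det):             # np.zeros(6) placeholder means "no boxes"
            continue
        for x1, y1, x2, y2, conf, cls in det.tolist():
            print(f"class {int(cls)} conf {conf:.2f} box ({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f})")
)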
 def main():
-    print("main")
+    print("onnx_inference")
+    print("onnxruntime :", onnxruntime.get_device())

     model = onnx.load(args.model)

-    x = torch.randn(1,3,608,608, requires_grad=True)
+    img = cv2.imread("C:/data/kitti_dataset/testing/Images/000315.png", cv2.IMREAD_COLOR)
+    #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    img = cv2.resize(img, (608, 608), interpolation=cv2.INTER_LINEAR)  # keyword needed: the 3rd positional arg is dst
+    #cv2.imshow("show input", img)
+    #cv2.waitKey(0)
+    img = np.transpose(np.array(img, dtype=np.float32) / 255, (2, 0, 1))  # HWC -> CHW, scaled to [0, 1]
+    np_img = np.expand_dims(img, axis=0)  # add batch dimension
+    print(np_img.dtype)
+    img = torch.FloatTensor(np.expand_dims(img, axis=0)).to(torch.device("cuda:0"))

-    print(onnx.checker.check_model(model))

+    print("input dim : ", img.shape)
+
+    print(onnx.checker.check_model(model))
+    x_test = torch.randn(1, 3, 608, 608, requires_grad=True).to(torch.device("cuda:0"))  # unused below
     def to_numpy(tensor):
         return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

-    ort_session = onnxruntime.InferenceSession(args.model)
+    providers = [
+        ('TensorrtExecutionProvider', {
+            'device_id': 0,
+            'trt_max_workspace_size': 2147483648,
+            'trt_fp16_enable': True,
+        }),
+        ('CUDAExecutionProvider', {
+            'device_id': 0,
+            'arena_extend_strategy': 'kNextPowerOfTwo',
+            'gpu_mem_limit': 2 * 1024 * 1024 * 1024,
+            'cudnn_conv_algo_search': 'EXHAUSTIVE',
+            'do_copy_in_default_stream': True,
+        })
+    ]
+    ort_session = onnxruntime.InferenceSession(args.model, providers=providers)

     # results computed by ONNX Runtime
-    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
+    ort_inputs = {ort_session.get_inputs()[0].name: np_img}

     ort_outs = ort_session.run(None, ort_inputs)
-    print("out : ", ort_outs)
+    print("out dim: ", ort_outs[0].shape)

 if __name__ == "__main__":
     args = parse_args()
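A caveat on the provider list above: ONNX Runtime silently skips any provider it cannot initialize (for example, when the TensorRT libraries are missing) and falls back down the list, so it is worth confirming what the session actually bound. A small check using only the standard onnxruntime API:

    print("available providers:", onnxruntime.get_available_providers())
    print("session providers  :", ort_session.get_providers())  # in priority order
    # Fail fast rather than silently running inference on CPU
    assert ort_session.get_providers()[0] != "CPUExecutionProvider", \
        "no GPU provider active"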

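Finally, x_test and to_numpy look like leftovers from the standard PyTorch export-validation recipe (the translated comment above also comes from that tutorial). If the original torch network is still at hand — torch_model below is an assumption, it is not defined in this file — the usual consistency check would be:

    # Sketch only: torch_model is assumed to be the network exported to args.model,
    # with a single-output head matching ort_outs[0]
    torch_out = torch_model(x_test)
    ort_outs = ort_session.run(None, {ort_session.get_inputs()[0].name: to_numpy(x_test)})
    # Loose tolerances absorb kernel-level numeric differences between backends
    np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-3, atol=1e-5)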