Skip to content

Commit d2179ad

Browse files
committed
optimize gpu inference for tflite model
1 parent 2c11242 commit d2179ad

File tree

5 files changed

+50
-31
lines changed

5 files changed

+50
-31
lines changed

convert_tflite.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,8 @@
88
import os
99
from core.config import cfg
1010

11-
flags.DEFINE_string('weights', './checkpoints/yolov3-416', 'path to weights file')
12-
flags.DEFINE_string('output', './checkpoints/yolov3-416-int8.tflite', 'path to output')
11+
flags.DEFINE_string('weights', './checkpoints/yolov4-416', 'path to weights file')
12+
flags.DEFINE_string('output', './checkpoints/yolov4-416-fp32.tflite', 'path to output')
1313
flags.DEFINE_integer('input_size', 416, 'input size of the exported model')
1414
flags.DEFINE_string('quantize_mode', 'float32', 'quantize mode (int8, float16, float32)')
1515
flags.DEFINE_string('dataset', "/Volumes/Elements/data/coco_dataset/coco/5k.txt", 'path to dataset')

core/yolov4.py

+37 -18
Original file line number | Diff line number | Diff line change
@@ -192,15 +192,16 @@ def decode_train(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYS
192192
return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)
193193

194194
def decode_tf(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1, 1, 1]):
195+
batch_size = tf.shape(conv_output)[0]
195196
conv_output = tf.reshape(conv_output,
196-
(tf.shape(conv_output)[0], output_size, output_size, 3, 5 + NUM_CLASS))
197+
(batch_size, output_size, output_size, 3, 5 + NUM_CLASS))
197198

198199
conv_raw_dxdy, conv_raw_dwdh, conv_raw_conf, conv_raw_prob = tf.split(conv_output, (2, 2, 1, NUM_CLASS),
199200
axis=-1)
200201

201202
xy_grid = tf.meshgrid(tf.range(output_size), tf.range(output_size))
202203
xy_grid = tf.expand_dims(tf.stack(xy_grid, axis=-1), axis=2) # [gx, gy, 1, 2]
203-
xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [tf.shape(conv_output)[0], 1, 1, 3, 1])
204+
xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [batch_size, 1, 1, 3, 1])
204205

205206
xy_grid = tf.cast(xy_grid, tf.float32)
206207

@@ -213,40 +214,55 @@ def decode_tf(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCAL
213214
pred_prob = tf.sigmoid(conv_raw_prob)
214215

215216
pred_prob = pred_conf * pred_prob
217+
pred_prob = tf.reshape(pred_prob, (batch_size, -1, NUM_CLASS))
218+
pred_xywh = tf.reshape(pred_xywh, (batch_size, -1, 4))
219+
216220
return pred_xywh, pred_prob
217221
# return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)
218222

219223
def decode_tflite(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1,1,1]):
220-
conv_output = tf.reshape(conv_output, (1, output_size, output_size, 3, 5 + NUM_CLASS))
221-
222-
conv_raw_dxdy, conv_raw_dwdh, conv_raw_conf, conv_raw_prob = tf.split(conv_output, (2, 2, 1, NUM_CLASS), axis=-1)
224+
conv_raw_dxdy_0, conv_raw_dwdh_0, conv_raw_score_0,\
225+
conv_raw_dxdy_1, conv_raw_dwdh_1, conv_raw_score_1,\
226+
conv_raw_dxdy_2, conv_raw_dwdh_2, conv_raw_score_2 = tf.split(conv_output, (2, 2, 1+NUM_CLASS, 2, 2, 1+NUM_CLASS,
227+
2, 2, 1+NUM_CLASS), axis=-1)
228+
229+
conv_raw_score = [conv_raw_score_0, conv_raw_score_1, conv_raw_score_2]
230+
for idx, score in enumerate(conv_raw_score):
231+
score = tf.sigmoid(score)
232+
score = score[:, :, :, 0:1] * score[:, :, :, 1:]
233+
conv_raw_score[idx] = tf.reshape(score, (1, -1, NUM_CLASS))
234+
pred_prob = tf.concat(conv_raw_score, axis=1)
235+
236+
conv_raw_dwdh = [conv_raw_dwdh_0, conv_raw_dwdh_1, conv_raw_dwdh_2]
237+
for idx, dwdh in enumerate(conv_raw_dwdh):
238+
dwdh = tf.exp(dwdh) * ANCHORS[i][idx]
239+
conv_raw_dwdh[idx] = tf.reshape(dwdh, (1, -1, 2))
240+
pred_wh = tf.concat(conv_raw_dwdh, axis=1)
223241

224242
xy_grid = tf.meshgrid(tf.range(output_size), tf.range(output_size))
225-
xy_grid = tf.expand_dims(tf.stack(xy_grid, axis=-1), axis=2) # [gx, gy, 1, 2]
226-
xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [1, 1, 1, 3, 1])
227-
243+
xy_grid = tf.stack(xy_grid, axis=-1) # [gx, gy, 2]
244+
xy_grid = tf.expand_dims(xy_grid, axis=0)
228245
xy_grid = tf.cast(xy_grid, tf.float32)
229246

230-
pred_xy = ((tf.sigmoid(conv_raw_dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * \
247+
conv_raw_dxdy = [conv_raw_dxdy_0, conv_raw_dxdy_1, conv_raw_dxdy_2]
248+
for idx, dxdy in enumerate(conv_raw_dxdy):
249+
dxdy = ((tf.sigmoid(dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * \
231250
STRIDES[i]
232-
pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i])
251+
conv_raw_dxdy[idx] = tf.reshape(dxdy, (1, -1, 2))
252+
pred_xy = tf.concat(conv_raw_dxdy, axis=1)
233253
pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)
234-
235-
pred_conf = tf.sigmoid(conv_raw_conf)
236-
pred_prob = tf.sigmoid(conv_raw_prob)
237-
238-
pred_prob = pred_conf * pred_prob
239254
return pred_xywh, pred_prob
240255
# return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)
241256

242257
def decode_trt(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1,1,1]):
243-
conv_output = tf.reshape(conv_output, (tf.shape(conv_output)[0], output_size, output_size, 3, 5 + NUM_CLASS))
258+
batch_size = tf.shape(conv_output)[0]
259+
conv_output = tf.reshape(conv_output, (batch_size, output_size, output_size, 3, 5 + NUM_CLASS))
244260

245261
conv_raw_dxdy, conv_raw_dwdh, conv_raw_conf, conv_raw_prob = tf.split(conv_output, (2, 2, 1, NUM_CLASS), axis=-1)
246262

247263
xy_grid = tf.meshgrid(tf.range(output_size), tf.range(output_size))
248264
xy_grid = tf.expand_dims(tf.stack(xy_grid, axis=-1), axis=2) # [gx, gy, 1, 2]
249-
xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [tf.shape(conv_output)[0], 1, 1, 3, 1])
265+
xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [batch_size, 1, 1, 3, 1])
250266

251267
# x = tf.tile(tf.expand_dims(tf.range(output_size, dtype=tf.float32), axis=0), [output_size, 1])
252268
# y = tf.tile(tf.expand_dims(tf.range(output_size, dtype=tf.float32), axis=1), [1, output_size])
@@ -258,14 +274,17 @@ def decode_trt(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCA
258274
# pred_xy = ((tf.sigmoid(conv_raw_dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * \
259275
# STRIDES[i]
260276
pred_xy = (tf.reshape(tf.sigmoid(conv_raw_dxdy), (-1, 2)) * XYSCALE[i] - 0.5 * (XYSCALE[i] - 1) + tf.reshape(xy_grid, (-1, 2))) * STRIDES[i]
261-
pred_xy = tf.reshape(pred_xy, (tf.shape(conv_output)[0], output_size, output_size, 3, 2))
277+
pred_xy = tf.reshape(pred_xy, (batch_size, output_size, output_size, 3, 2))
262278
pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i])
263279
pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)
264280

265281
pred_conf = tf.sigmoid(conv_raw_conf)
266282
pred_prob = tf.sigmoid(conv_raw_prob)
267283

268284
pred_prob = pred_conf * pred_prob
285+
286+
pred_prob = tf.reshape(pred_prob, (batch_size, -1, NUM_CLASS))
287+
pred_xywh = tf.reshape(pred_xywh, (batch_size, -1, 4))
269288
return pred_xywh, pred_prob
270289
# return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)
271290

detect.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -55,10 +55,10 @@ def main(_argv):
5555
interpreter.set_tensor(input_details[0]['index'], images_data)
5656
interpreter.invoke()
5757
pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
58-
if FLAGS.model == 'yolov4' and FLAGS.tiny == True:
59-
boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25)
58+
if FLAGS.model == 'yolov3' and FLAGS.tiny == True:
59+
boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25, input_shape=tf.constant([input_size, input_size]))
6060
else:
61-
boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25)
61+
boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25, input_shape=tf.constant([input_size, input_size]))
6262
else:
6363
saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
6464
infer = saved_model_loaded.signatures['serving_default']

detectvideo.py

+5 -3
Original file line number | Diff line number | Diff line change
@@ -63,10 +63,12 @@ def main(_argv):
6363
interpreter.set_tensor(input_details[0]['index'], image_data)
6464
interpreter.invoke()
6565
pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
66-
if FLAGS.model == 'yolov4' and FLAGS.tiny == True:
67-
boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25)
66+
if FLAGS.model == 'yolov3' and FLAGS.tiny == True:
67+
boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25,
68+
input_shape=tf.constant([input_size, input_size]))
6869
else:
69-
boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25)
70+
boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25,
71+
input_shape=tf.constant([input_size, input_size]))
7072
else:
7173
batch_data = tf.constant(image_data)
7274
pred_bbox = infer(batch_data)

save_model.py

+3 -5
Original file line number | Diff line number | Diff line change
@@ -38,14 +38,12 @@ def save_tf():
3838
output_tensors = decode(fm, FLAGS.input_size // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework)
3939
bbox_tensors.append(output_tensors[0])
4040
prob_tensors.append(output_tensors[1])
41-
pred_bbox = [tf.reshape(x, (tf.shape(x)[0], -1, tf.shape(x)[-1])) for x in bbox_tensors]
42-
pred_bbox = tf.concat(pred_bbox, axis=1)
43-
pred_prob = [tf.reshape(x, (tf.shape(x)[0], -1, tf.shape(x)[-1])) for x in prob_tensors]
44-
pred_prob = tf.concat(pred_prob, axis=1)
41+
pred_bbox = tf.concat(bbox_tensors, axis=1)
42+
pred_prob = tf.concat(prob_tensors, axis=1)
4543
if FLAGS.framework == 'tflite':
4644
pred = (pred_bbox, pred_prob)
4745
else:
48-
boxes, pred_conf = filter_boxes(pred_bbox, pred_prob, score_threshold=FLAGS.score_thres)
46+
boxes, pred_conf = filter_boxes(pred_bbox, pred_prob, score_threshold=FLAGS.score_thres, input_shape=tf.constant([FLAGS.input_size, FLAGS.input_size]))
4947
pred = tf.concat([boxes, pred_conf], axis=-1)
5048
model = tf.keras.Model(input_layer, pred)
5149
utils.load_weights(model, FLAGS.weights, FLAGS.model, FLAGS.tiny)

0 commit comments

Comments
 (0)