@@ -192,15 +192,16 @@ def decode_train(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYS
     return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)

 def decode_tf(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1, 1, 1]):
+    batch_size = tf.shape(conv_output)[0]
     conv_output = tf.reshape(conv_output,
-                             (tf.shape(conv_output)[0], output_size, output_size, 3, 5 + NUM_CLASS))
+                             (batch_size, output_size, output_size, 3, 5 + NUM_CLASS))

     conv_raw_dxdy, conv_raw_dwdh, conv_raw_conf, conv_raw_prob = tf.split(conv_output, (2, 2, 1, NUM_CLASS),
                                                                           axis=-1)

     xy_grid = tf.meshgrid(tf.range(output_size), tf.range(output_size))
     xy_grid = tf.expand_dims(tf.stack(xy_grid, axis=-1), axis=2)  # [gx, gy, 1, 2]
-    xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [tf.shape(conv_output)[0], 1, 1, 3, 1])
+    xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [batch_size, 1, 1, 3, 1])

     xy_grid = tf.cast(xy_grid, tf.float32)

@@ -213,40 +214,55 @@ def decode_tf(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCAL
     pred_prob = tf.sigmoid(conv_raw_prob)

     pred_prob = pred_conf * pred_prob
+    pred_prob = tf.reshape(pred_prob, (batch_size, -1, NUM_CLASS))
+    pred_xywh = tf.reshape(pred_xywh, (batch_size, -1, 4))
+
     return pred_xywh, pred_prob
     # return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)

 def decode_tflite(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1,1,1]):
-    conv_output = tf.reshape(conv_output, (1, output_size, output_size, 3, 5 + NUM_CLASS))
-
-    conv_raw_dxdy, conv_raw_dwdh, conv_raw_conf, conv_raw_prob = tf.split(conv_output, (2, 2, 1, NUM_CLASS), axis=-1)
+    conv_raw_dxdy_0, conv_raw_dwdh_0, conv_raw_score_0,\
+    conv_raw_dxdy_1, conv_raw_dwdh_1, conv_raw_score_1,\
+    conv_raw_dxdy_2, conv_raw_dwdh_2, conv_raw_score_2 = tf.split(conv_output, (2, 2, 1 + NUM_CLASS, 2, 2, 1 + NUM_CLASS,
+                                                                                2, 2, 1 + NUM_CLASS), axis=-1)
+
+    conv_raw_score = [conv_raw_score_0, conv_raw_score_1, conv_raw_score_2]
+    for idx, score in enumerate(conv_raw_score):
+        score = tf.sigmoid(score)
+        score = score[:, :, :, 0:1] * score[:, :, :, 1:]
+        conv_raw_score[idx] = tf.reshape(score, (1, -1, NUM_CLASS))
+    pred_prob = tf.concat(conv_raw_score, axis=1)
+
+    conv_raw_dwdh = [conv_raw_dwdh_0, conv_raw_dwdh_1, conv_raw_dwdh_2]
+    for idx, dwdh in enumerate(conv_raw_dwdh):
+        dwdh = tf.exp(dwdh) * ANCHORS[i][idx]
+        conv_raw_dwdh[idx] = tf.reshape(dwdh, (1, -1, 2))
+    pred_wh = tf.concat(conv_raw_dwdh, axis=1)

     xy_grid = tf.meshgrid(tf.range(output_size), tf.range(output_size))
-    xy_grid = tf.expand_dims(tf.stack(xy_grid, axis=-1), axis=2)  # [gx, gy, 1, 2]
-    xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [1, 1, 1, 3, 1])
-
+    xy_grid = tf.stack(xy_grid, axis=-1)  # [gx, gy, 2]
+    xy_grid = tf.expand_dims(xy_grid, axis=0)
     xy_grid = tf.cast(xy_grid, tf.float32)

-    pred_xy = ((tf.sigmoid(conv_raw_dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * \
+    conv_raw_dxdy = [conv_raw_dxdy_0, conv_raw_dxdy_1, conv_raw_dxdy_2]
+    for idx, dxdy in enumerate(conv_raw_dxdy):
+        dxdy = ((tf.sigmoid(dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * \
               STRIDES[i]
-    pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i])
+        conv_raw_dxdy[idx] = tf.reshape(dxdy, (1, -1, 2))
+    pred_xy = tf.concat(conv_raw_dxdy, axis=1)

     pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)
-
-    pred_conf = tf.sigmoid(conv_raw_conf)
-    pred_prob = tf.sigmoid(conv_raw_prob)
-
-    pred_prob = pred_conf * pred_prob
     return pred_xywh, pred_prob
     # return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)

 def decode_trt(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1,1,1]):
-    conv_output = tf.reshape(conv_output, (tf.shape(conv_output)[0], output_size, output_size, 3, 5 + NUM_CLASS))
+    batch_size = tf.shape(conv_output)[0]
+    conv_output = tf.reshape(conv_output, (batch_size, output_size, output_size, 3, 5 + NUM_CLASS))

     conv_raw_dxdy, conv_raw_dwdh, conv_raw_conf, conv_raw_prob = tf.split(conv_output, (2, 2, 1, NUM_CLASS), axis=-1)

     xy_grid = tf.meshgrid(tf.range(output_size), tf.range(output_size))
     xy_grid = tf.expand_dims(tf.stack(xy_grid, axis=-1), axis=2)  # [gx, gy, 1, 2]
-    xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [tf.shape(conv_output)[0], 1, 1, 3, 1])
+    xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [batch_size, 1, 1, 3, 1])

     # x = tf.tile(tf.expand_dims(tf.range(output_size, dtype=tf.float32), axis=0), [output_size, 1])
     # y = tf.tile(tf.expand_dims(tf.range(output_size, dtype=tf.float32), axis=1), [1, output_size])
@@ -258,14 +274,17 @@ def decode_trt(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCA
     # pred_xy = ((tf.sigmoid(conv_raw_dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * \
     #           STRIDES[i]
     pred_xy = (tf.reshape(tf.sigmoid(conv_raw_dxdy), (-1, 2)) * XYSCALE[i] - 0.5 * (XYSCALE[i] - 1) + tf.reshape(xy_grid, (-1, 2))) * STRIDES[i]
-    pred_xy = tf.reshape(pred_xy, (tf.shape(conv_output)[0], output_size, output_size, 3, 2))
+    pred_xy = tf.reshape(pred_xy, (batch_size, output_size, output_size, 3, 2))
     pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i])
     pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)

     pred_conf = tf.sigmoid(conv_raw_conf)
     pred_prob = tf.sigmoid(conv_raw_prob)

     pred_prob = pred_conf * pred_prob
+
+    pred_prob = tf.reshape(pred_prob, (batch_size, -1, NUM_CLASS))
+    pred_xywh = tf.reshape(pred_xywh, (batch_size, -1, 4))
     return pred_xywh, pred_prob
     # return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)

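Note on the added reshapes: the decode paths now return pred_xywh flattened to (batch_size, -1, 4) and pred_prob flattened to (batch_size, -1, NUM_CLASS), which matches the batched boxes-plus-scores layout that tf.image.combined_non_max_suppression expects. The sketch below shows how such flattened outputs could be post-processed downstream; the postprocess helper, the per-scale concatenation, and the threshold values are illustrative assumptions, not part of this diff.

import tensorflow as tf

def postprocess(decoded_outputs, score_threshold=0.25, iou_threshold=0.45):
    # decoded_outputs: assumed list of (pred_xywh, pred_prob) tuples, one per detection
    # scale, shaped (batch, -1, 4) and (batch, -1, NUM_CLASS) as returned by the decoders.
    boxes = tf.concat([xywh for xywh, _ in decoded_outputs], axis=1)
    scores = tf.concat([prob for _, prob in decoded_outputs], axis=1)
    # Convert (cx, cy, w, h) in input-pixel units into (y1, x1, y2, x2) corners for NMS.
    cx, cy, w, h = tf.split(boxes, 4, axis=-1)
    boxes = tf.concat([cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2], axis=-1)
    # combined_non_max_suppression takes boxes as (batch, num_boxes, q, 4); q=1 shares
    # a single box across all classes.
    return tf.image.combined_non_max_suppression(
        boxes=tf.expand_dims(boxes, axis=2),
        scores=scores,
        max_output_size_per_class=50,
        max_total_size=50,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        clip_boxes=False)  # boxes are in pixels here, so skip clipping to [0, 1]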