Skip to content

Commit 4246764

Browse files
committed
update ciou loss and xyscale
1 parent 2633f44 commit 4246764

8 files changed

+113
-44
lines changed

convert_tflite.py

-17
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,6 @@ def representative_data_gen():
2727
else:
2828
continue
2929

30-
# def apply_quantization_to_dense(layer):
31-
# # print(layer.name)
32-
# if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.BatchNormalization,
33-
# tf.keras.layers.ZeroPadding2D, tf.keras.layers.ReLU)):
34-
# print(layer.name)
35-
# return tfmot.quantization.keras.quantize_annotate_layer(layer)
36-
# return layer
37-
3830
def save_tflite():
3931
input_layer = tf.keras.layers.Input([FLAGS.input_size, FLAGS.input_size, 3])
4032
if FLAGS.tiny:
@@ -58,15 +50,6 @@ def save_tflite():
5850
model.summary()
5951
utils.load_weights(model, FLAGS.weights)
6052

61-
62-
63-
# annotated_model = tf.keras.models.clone_model(
64-
# model,
65-
# clone_function=apply_quantization_to_dense,
66-
# )
67-
# quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
68-
# quant_aware_model.summary()
69-
7053
converter = tf.lite.TFLiteConverter.from_keras_model(model)
7154
if FLAGS.quantize_mode == 'int8':
7255
converter.optimizations = [tf.lite.Optimize.DEFAULT]

core/common.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def convolutional(input_layer, filters_shape, downsample=False, activate=True, b
3838
conv = tf.nn.leaky_relu(conv, alpha=0.1)
3939
elif activate_type == "mish":
4040
conv = mish(conv)
41+
# conv = softplus(conv)
42+
# conv = conv * tf.math.tanh(tf.math.softplus(conv))
43+
# conv = conv * tf.tanh(softplus(conv))
4144
# conv = tf.nn.leaky_relu(conv, alpha=0.1)
4245
# conv = tfa.activations.mish(conv)
4346
# conv = conv * tf.nn.tanh(tf.keras.activations.relu(tf.nn.softplus(conv), max_value=20))
@@ -46,9 +49,22 @@ def convolutional(input_layer, filters_shape, downsample=False, activate=True, b
4649
# if activate == True: conv = tf.keras.layers.ReLU()(conv)
4750

4851
return conv
49-
52+
def softplus(x, threshold=20.):
    """Element-wise softplus, ``log(1 + exp(x))``, with asymptotic shortcuts.

    The three regimes are selected per element:

    * ``x > threshold``   -> ``x``       (softplus is ~identity there)
    * ``x < -threshold``  -> ``exp(x)``  (softplus is ~exp there)
    * otherwise           -> ``log(1 + exp(x))``

    The selection uses ``tf.where`` rather than ``tf.case``: ``tf.case``
    requires scalar boolean predicates, so it cannot branch per element on an
    activation tensor.

    Args:
        x: Float tensor (or a value convertible to one) of any shape.
        threshold: Positive cut-off separating the three regimes.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    x = tf.convert_to_tensor(x)
    threshold = tf.cast(threshold, x.dtype)

    # tf.where evaluates both operands for every element, so cap the argument
    # of exp: elements above the threshold take the identity branch anyway and
    # must not produce inf (or NaN gradients) in the branches they do not use.
    exp_x = tf.exp(tf.minimum(x, threshold))

    return tf.where(
        x > threshold,
        x,
        tf.where(x < -threshold, exp_x, tf.math.log(1 + exp_x)),
    )
5064
def mish(x):
    """Apply the Mish activation, ``x * tanh(log(1 + exp(x)))``, to ``x``.

    The computation is wrapped in a Keras ``Lambda`` layer, which is then
    called on ``x``; the layer's output is returned.
    """
    mish_layer = tf.keras.layers.Lambda(
        lambda t: t * tf.tanh(tf.math.log(1 + tf.exp(t)))
    )
    return mish_layer(x)
66+
# return tf.keras.layers.Lambda(lambda x: softplus(x))(x)
67+
# return tf.keras.layers.Lambda(lambda x: x * tf.tanh(softplus(x)))(x)
5268

5369
def residual_block(input_layer, input_channel, filter_num1, filter_num2):
5470
short_cut = input_layer

core/config.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
# Set the class name
1515
__C.YOLO.CLASSES = "./data/classes/coco.names"
1616
__C.YOLO.ANCHORS = "./data/anchors/coco_anchors.txt"
17-
__C.YOLO.ANCHORS_TINY = "./data/anchors/basline_tiny_anchors.txt"
17+
__C.YOLO.ANCHORS_TINY = "./data/anchors/basline_tiny_anchors.txt"
1818
__C.YOLO.STRIDES = [8, 16, 32]
1919
__C.YOLO.STRIDES_TINY = [16, 32]
20+
__C.YOLO.XYSCALE = [1.2, 1.1, 1.05]
2021
__C.YOLO.ANCHOR_PER_SCALE = 3
2122
__C.YOLO.IOU_LOSS_THRESH = 0.5
2223

@@ -26,8 +27,8 @@
2627

2728
__C.TRAIN.ANNOT_PATH = "./data/dataset/val2017.txt"
2829
__C.TRAIN.BATCH_SIZE = 4
29-
# __C.TRAIN.INPUT_SIZE = [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
30-
__C.TRAIN.INPUT_SIZE = [416]
30+
__C.TRAIN.INPUT_SIZE = [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
31+
# __C.TRAIN.INPUT_SIZE = [416]
3132
__C.TRAIN.DATA_AUG = True
3233
__C.TRAIN.LR_INIT = 1e-3
3334
__C.TRAIN.LR_END = 1e-6
@@ -44,7 +45,7 @@
4445
__C.TEST.INPUT_SIZE = 416
4546
__C.TEST.DATA_AUG = False
4647
__C.TEST.DECTECTED_IMAGE_PATH = "./data/detection/"
47-
__C.TEST.SCORE_THRESHOLD = 0.3
48+
__C.TEST.SCORE_THRESHOLD = 0.25
4849
__C.TEST.IOU_THRESHOLD = 0.5
4950

5051

core/utils.py

+59-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import random
33
import colorsys
44
import numpy as np
5+
import tensorflow as tf
56
from core.config import cfg
67

78
def load_weights_tiny(model, weights_file):
@@ -220,6 +221,40 @@ def bboxes_iou(boxes1, boxes2):
220221

221222
return ious
222223

224+
def bboxes_ciou(boxes1, boxes2):
    """Compute the Complete-IoU (CIoU) score between corner-format boxes.

    CIoU = IoU - (d + alpha * v), where ``d`` is the squared distance between
    box centres divided by the squared diagonal of the smallest box enclosing
    both boxes, ``v`` measures aspect-ratio disagreement and ``alpha`` weights
    ``v``.

    Args:
        boxes1: Array-like whose last axis holds
            ``(x_min, y_min, x_max, y_max)``; its aspect ratio is used as the
            predicted one.
        boxes2: Array-like with the same layout; its aspect ratio is used as
            the ground-truth one.

    Returns:
        Array of CIoU values, one per box pair (last axis removed).
    """
    boxes1 = np.array(boxes1)
    boxes2 = np.array(boxes2)

    # Smallest enclosing box: the MIN of the top-left corners and the MAX of
    # the bottom-right corners. (Taking the max for left/up would shrink the
    # enclosing box and underestimate the normalising diagonal.)
    left = np.minimum(boxes1[..., 0], boxes2[..., 0])
    up = np.minimum(boxes1[..., 1], boxes2[..., 1])
    right = np.maximum(boxes1[..., 2], boxes2[..., 2])
    down = np.maximum(boxes1[..., 3], boxes2[..., 3])

    # Squared diagonal of the enclosing box.
    c = (right - left) * (right - left) + (down - up) * (down - up)
    iou = bboxes_iou(boxes1, boxes2)

    # Box centres.
    ax = (boxes1[..., 0] + boxes1[..., 2]) / 2
    ay = (boxes1[..., 1] + boxes1[..., 3]) / 2
    bx = (boxes2[..., 0] + boxes2[..., 2]) / 2
    by = (boxes2[..., 1] + boxes2[..., 3]) / 2

    # Squared centre distance, normalised by the enclosing diagonal.
    u = (ax - bx) * (ax - bx) + (ay - by) * (ay - by)
    d = u / c

    # Widths and heights for the aspect-ratio term.
    aw = boxes1[..., 2] - boxes1[..., 0]
    ah = boxes1[..., 3] - boxes1[..., 1]
    bw = boxes2[..., 2] - boxes2[..., 0]
    bh = boxes2[..., 3] - boxes2[..., 1]

    ar_gt = bw / bh
    ar_pred = aw / ah

    atan_diff = np.arctan(ar_gt) - np.arctan(ar_pred)
    ar_loss = 4 / (np.pi * np.pi) * atan_diff * atan_diff
    # The small constant keeps alpha finite when iou == 1 and ar_loss == 0.
    alpha = ar_loss / (1 - iou + ar_loss + 0.000001)
    ciou_term = d + alpha * ar_loss

    return iou - ciou_term
223258

224259
def nms(bboxes, iou_threshold, sigma=0.3, method='nms'):
225260
"""
@@ -258,7 +293,30 @@ def nms(bboxes, iou_threshold, sigma=0.3, method='nms'):
258293

259294
return best_bboxes
260295

261-
296+
def diounms_sort(bboxes, iou_threshold, sigma=0.3, method='nms', beta_nms=0.6):
    """Placeholder for DIoU-NMS: not implemented yet.

    All arguments are currently ignored and a new empty list is returned on
    every call.
    """
    return []
299+
def postprocess_bbbox(pred_bbox, XYSCALE, ANCHORS, STRIDES):
    """Decode raw per-scale head outputs into one flat tensor of predictions.

    For every scale ``i`` the first four channels of ``pred_bbox[i]`` are
    overwritten IN PLACE with decoded box centres and sizes; all scales are
    then flattened to rows and concatenated.

    Args:
        pred_bbox: Sequence of 5-D arrays indexed as
            ``[batch, gy, gx, anchor, channel]`` that support slice assignment
            (presumably NumPy arrays from the TFLite interpreter -- confirm
            against the caller). Channels 0:2 are raw xy offsets, 2:4 raw wh.
        XYSCALE: Per-scale factor applied to the sigmoid of the xy offsets.
        ANCHORS: Per-scale anchor sizes multiplied into ``exp(raw wh)``.
        STRIDES: Per-scale stride converting grid units to pixels.

    Returns:
        A 2-D tensor: every prediction of every scale as one row.
    """
    for i, pred in enumerate(pred_bbox):
        output_size = pred.shape[1]
        conv_raw_dxdy = pred[:, :, :, :, 0:2]
        conv_raw_dwdh = pred[:, :, :, :, 2:4]

        # Grid of cell coordinates, one (x, y) pair per cell and anchor.
        xy_grid = np.meshgrid(np.arange(output_size), np.arange(output_size))
        xy_grid = np.expand_dims(np.stack(xy_grid, axis=-1), axis=2)  # [gx, gy, 1, 2]
        xy_grid = np.tile(np.expand_dims(xy_grid, axis=0), [1, 1, 1, 3, 1])
        # np.float was only an alias for the builtin float (float64) and was
        # removed in NumPy 1.24, so name the dtype explicitly.
        xy_grid = xy_grid.astype(np.float64)

        # Scaled sigmoid: stretches the offset range symmetrically around the
        # cell centre (XYSCALE == 1 reduces to the plain sigmoid decoding).
        pred_xy = ((tf.sigmoid(conv_raw_dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * STRIDES[i]
        # Sizes are scaled by the anchors only; no stride factor is applied.
        pred_wh = tf.exp(conv_raw_dwdh) * ANCHORS[i]
        pred[:, :, :, :, 0:4] = tf.concat([pred_xy, pred_wh], axis=-1)

    pred_bbox = [tf.reshape(x, (-1, tf.shape(x)[-1])) for x in pred_bbox]
    pred_bbox = tf.concat(pred_bbox, axis=0)
    return pred_bbox
262320
def postprocess_boxes(pred_bbox, org_img_shape, input_size, score_threshold):
263321

264322
valid_scale=[0, np.inf]

core/yolov4.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,32 @@ def bbox_iou(boxes1, boxes2):
178178

179179
return 1.0 * inter_area / union_area
180180

181+
def bbox_ciou(boxes1, boxes2):
    """Compute the Complete-IoU (CIoU) between centre-format box tensors.

    CIoU = IoU - (d + alpha * v), where ``d`` is the squared centre distance
    divided by the squared diagonal of the smallest box enclosing both boxes,
    ``v`` measures aspect-ratio disagreement and ``alpha`` weights ``v``.

    Args:
        boxes1: Tensor whose last axis is ``(cx, cy, w, h)``; its aspect ratio
            is used as the predicted one.
        boxes2: Tensor with the same layout; its aspect ratio is used as the
            ground-truth one.

    Returns:
        Tensor of CIoU values with the last axis removed.
    """
    # Convert (cx, cy, w, h) to corner form (x_min, y_min, x_max, y_max).
    boxes1_coor = tf.concat([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
                             boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1)
    boxes2_coor = tf.concat([boxes2[..., :2] - boxes2[..., 2:] * 0.5,
                             boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1)

    # Smallest enclosing box: MIN of the top-left corners, MAX of the
    # bottom-right corners. (Using the max for left/up would shrink the box
    # and underestimate the normalising diagonal.)
    left = tf.minimum(boxes1_coor[..., 0], boxes2_coor[..., 0])
    up = tf.minimum(boxes1_coor[..., 1], boxes2_coor[..., 1])
    right = tf.maximum(boxes1_coor[..., 2], boxes2_coor[..., 2])
    down = tf.maximum(boxes1_coor[..., 3], boxes2_coor[..., 3])

    # Squared diagonal of the enclosing box.
    c = (right - left) * (right - left) + (down - up) * (down - up)
    iou = bbox_iou(boxes1, boxes2)

    # Squared distance between the two box centres.
    dx = boxes1[..., 0] - boxes2[..., 0]
    dy = boxes1[..., 1] - boxes2[..., 1]
    u = dx * dx + dy * dy
    # divide_no_nan returns 0 where the denominator is 0, so zero-sized boxes
    # cannot inject NaN/inf into the result.
    d = tf.math.divide_no_nan(u, c)

    ar_gt = tf.math.divide_no_nan(boxes2[..., 2], boxes2[..., 3])
    ar_pred = tf.math.divide_no_nan(boxes1[..., 2], boxes1[..., 3])

    atan_diff = tf.atan(ar_gt) - tf.atan(ar_pred)
    ar_loss = 4 / (np.pi * np.pi) * atan_diff * atan_diff
    # The small constant keeps alpha finite when iou == 1 and ar_loss == 0.
    alpha = ar_loss / (1 - iou + ar_loss + 0.000001)
    ciou_term = d + alpha * ar_loss

    return iou - ciou_term
206+
181207
def bbox_giou(boxes1, boxes2):
182208

183209
boxes1 = tf.concat([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
@@ -209,7 +235,6 @@ def bbox_giou(boxes1, boxes2):
209235

210236
return giou
211237

212-
213238
def compute_loss(pred, conv, label, bboxes, i=0):
214239

215240
conv_shape = tf.shape(conv)

detect.py

+4-20
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def main(_argv):
2424
else:
2525
STRIDES = np.array(cfg.YOLO.STRIDES)
2626
ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS, FLAGS.tiny)
27+
XYSCALE = cfg.YOLO.XYSCALE
2728
input_size = FLAGS.size
2829
image_path = FLAGS.image
2930

@@ -69,26 +70,9 @@ def main(_argv):
6970
interpreter.invoke()
7071
pred_bbox = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
7172

72-
for i, pred in enumerate(pred_bbox):
73-
conv_shape = pred.shape
74-
output_size = conv_shape[1]
75-
conv_raw_dxdy = pred[:, :, :, :, 0:2]
76-
conv_raw_dwdh = pred[:, :, :, :, 2:4]
77-
xy_grid = np.meshgrid(np.arange(output_size), np.arange(output_size))
78-
xy_grid = np.expand_dims(np.stack(xy_grid, axis=-1), axis=2) # [gx, gy, 1, 2]
79-
80-
xy_grid = np.tile(tf.expand_dims(xy_grid, axis=0), [1, 1, 1, 3, 1])
81-
xy_grid = xy_grid.astype(np.float)
82-
83-
pred_xy = (tf.sigmoid(conv_raw_dxdy) + xy_grid) * STRIDES[i]
84-
# pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i]) * STRIDES[i]
85-
pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i])
86-
pred[:, :, :, :, 0:4] = tf.concat([pred_xy, pred_wh], axis=-1)
87-
88-
pred_bbox = [tf.reshape(x, (-1, tf.shape(x)[-1])) for x in pred_bbox]
89-
pred_bbox = tf.concat(pred_bbox, axis=0)
90-
bboxes = utils.postprocess_boxes(pred_bbox, original_image_size, input_size, 0.3)
91-
bboxes = utils.nms(bboxes, 0.45, method='nms')
73+
pred_bbox = utils.postprocess_bbbox(pred_bbox, XYSCALE, ANCHORS, STRIDES)
74+
bboxes = utils.postprocess_boxes(pred_bbox, original_image_size, input_size, 0.25)
75+
bboxes = utils.nms(bboxes, 0.213, method='nms')
9276

9377
image = utils.draw_bbox(original_image, bboxes)
9478
image = Image.fromarray(image)

requirements-gpu.txt

+1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ tqdm
55
absl-py
66
matplotlib
77
easydict
8+
pillow
89
tensorflow_addons==0.9.1

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ tensorflow==2.1.0
55
absl-py
66
easydict
77
matplotlib
8+
pillow
89
tensorflow_addons==0.9.1

0 commit comments

Comments
 (0)