9
9
10
10
11
11
class YoloDatasets (keras .utils .Sequence ):
12
- def __init__ (self , annotation_lines , input_shape , anchors , batch_size , num_classes , anchors_mask , epoch_now , epoch_length , mosaic , train , mosaic_ratio = 0.7 ):
12
+ def __init__ (self , annotation_lines , input_shape , anchors , batch_size , num_classes , anchors_mask , epoch_now , epoch_length , \
13
+ mosaic , mixup , mosaic_prob , mixup_prob , train , special_aug_ratio = 0.7 ):
13
14
self .annotation_lines = annotation_lines
14
15
self .length = len (self .annotation_lines )
15
16
@@ -21,10 +22,11 @@ def __init__(self, annotation_lines, input_shape, anchors, batch_size, num_class
21
22
self .epoch_now = epoch_now - 1
22
23
self .epoch_length = epoch_length
23
24
self .mosaic = mosaic
25
+ self .mosaic_prob = mosaic_prob
26
+ self .mixup = mixup
27
+ self .mixup_prob = mixup_prob
24
28
self .train = train
25
- self .mosaic_ratio = mosaic_ratio
26
-
27
- self .threshold = 4
29
+ self .special_aug_ratio = special_aug_ratio
28
30
29
31
def __len__ (self ):
30
32
return math .ceil (len (self .annotation_lines ) / float (self .batch_size ))
@@ -38,14 +40,16 @@ def __getitem__(self, index):
38
40
# 训练时进行数据的随机增强
39
41
# 验证时不进行数据的随机增强
40
42
#---------------------------------------------------#
41
- if self .mosaic :
42
- if self .rand () < 0.5 and self .epoch_now < self .epoch_length * self .mosaic_ratio :
43
- lines = sample (self .annotation_lines , 3 )
44
- lines .append (self .annotation_lines [i ])
45
- shuffle (lines )
46
- image , box = self .get_random_data_with_Mosaic (lines , self .input_shape )
47
- else :
48
- image , box = self .get_random_data (self .annotation_lines [i ], self .input_shape , random = self .train )
43
+ if self .mosaic and self .rand () < self .mosaic_prob and self .epoch_now < self .epoch_length * self .special_aug_ratio :
44
+ lines = sample (self .annotation_lines , 3 )
45
+ lines .append (self .annotation_lines [i ])
46
+ shuffle (lines )
47
+ image , box = self .get_random_data_with_Mosaic (lines , self .input_shape )
48
+
49
+ if self .mixup and self .rand () < self .mixup_prob :
50
+ lines = sample (self .annotation_lines , 1 )
51
+ image_2 , box_2 = self .get_random_data (lines [0 ], self .input_shape , random = self .train )
52
+ image , box = self .get_random_data_with_MixUp (image , box , image_2 , box_2 )
49
53
else :
50
54
image , box = self .get_random_data (self .annotation_lines [i ], self .input_shape , random = self .train )
51
55
image_data .append (preprocess_input (np .array (image , np .float32 )))
@@ -368,6 +372,25 @@ def get_random_data_with_Mosaic(self, annotation_line, input_shape, max_boxes=50
368
372
box_data [:len (new_boxes )] = new_boxes
369
373
return new_image , box_data
370
374
375
+ def get_random_data_with_MixUp (self , image_1 , box_1 , image_2 , box_2 , max_boxes = 500 ):
376
+ new_image = np .array (image_1 , np .float32 ) * 0.5 + np .array (image_2 , np .float32 ) * 0.5
377
+
378
+ box_1_wh = box_1 [:, 2 :4 ] - box_1 [:, 0 :2 ]
379
+ box_1_valid = box_1_wh [:, 0 ] > 0
380
+
381
+ box_2_wh = box_2 [:, 2 :4 ] - box_2 [:, 0 :2 ]
382
+ box_2_valid = box_2_wh [:, 0 ] > 0
383
+
384
+ new_boxes = np .concatenate ([box_1 [box_1_valid , :], box_2 [box_2_valid , :]], axis = 0 )
385
+ #---------------------------------#
386
+ # 将box进行调整
387
+ #---------------------------------#
388
+ box_data = np .zeros ((max_boxes , 5 ))
389
+ if len (new_boxes )> 0 :
390
+ if len (new_boxes )> max_boxes : new_boxes = new_boxes [:max_boxes ]
391
+ box_data [:len (new_boxes )] = new_boxes
392
+ return new_image , box_data
393
+
371
394
def preprocess_true_boxes (self , true_boxes , input_shape , anchors , num_classes ):
372
395
assert (true_boxes [..., 4 ]< num_classes ).all (), 'class id must be less than num_classes'
373
396
#-----------------------------------------------------------#
0 commit comments