@@ -56,7 +56,6 @@ def __init__(self,
56
56
aug_h_flip ,
57
57
lr_policy = 'step' ,
58
58
model_name = 'model' ,
59
- auto_balance = False ,
60
59
live_loss_plot = False ,
61
60
checkpoint_interval = 0 ,
62
61
show_class_activation_map = False ,
@@ -73,7 +72,6 @@ def __init__(self,
73
72
self .batch_size = batch_size
74
73
self .iterations = iterations
75
74
self .lr_policy = lr_policy
76
- self .auto_balance = auto_balance
77
75
self .live_loss_plot_flag = live_loss_plot
78
76
self .max_val_acc = 0.0
79
77
self .show_class_activation_map = show_class_activation_map
@@ -87,8 +85,8 @@ def __init__(self,
87
85
train_image_path = self .unify_path (train_image_path )
88
86
validation_image_path = self .unify_path (validation_image_path )
89
87
90
- self .train_image_paths , train_class_names , self . class_weights , _ = self .init_image_paths (train_image_path )
91
- self .validation_image_paths , validation_class_names , _ , self .include_unknown = self .init_image_paths (validation_image_path )
88
+ self .train_image_paths , train_class_names , _ = self .init_image_paths (train_image_path )
89
+ self .validation_image_paths , validation_class_names , self .include_unknown = self .init_image_paths (validation_image_path )
92
90
if len (self .train_image_paths ) == 0 :
93
91
print (f'no images in train_image_path : { train_image_path } ' )
94
92
exit (0 )
@@ -184,21 +182,15 @@ def init_image_paths(self, image_path):
184
182
print ()
185
183
class_names = sorted (list (class_name_set ))
186
184
total_data_count = float (sum (class_counts )) + unknown_class_count
187
- class_weights = [1.0 - (count / total_data_count ) if self .auto_balance else 0.0 for count in class_counts ]
188
- return image_paths , class_names , class_weights , include_unknown
185
+ return image_paths , class_names , include_unknown
189
186
190
187
@tf .function
191
- def compute_gradient (self , model , optimizer , batch_x , y_true , loss_function , class_weights ):
188
+ def compute_gradient (self , model , optimizer , batch_x , y_true , loss_function ):
192
189
with tf .GradientTape () as tape :
193
190
y_pred = self .model (batch_x , training = True )
194
191
loss = loss_function (y_true , y_pred )
195
- batch_size = tf .cast (tf .shape (y_true )[0 ], dtype = tf .int32 )
196
- class_weights_t = tf .convert_to_tensor (class_weights , dtype = y_pred .dtype )
197
- if tf .reduce_sum (class_weights_t ) > 0.0 :
198
- class_weights_t = tf .repeat (tf .expand_dims (class_weights_t , axis = 0 ), repeats = batch_size , axis = 0 )
199
- class_weights_t = tf .where (y_true == 1.0 , class_weights_t , 1.0 - class_weights_t )
200
- loss *= class_weights_t
201
- loss = tf .reduce_sum (loss ) / tf .cast (batch_size , dtype = loss .dtype )
192
+ batch_size_f = tf .cast (tf .shape (y_true )[0 ], dtype = loss .dtype )
193
+ loss = tf .reduce_sum (loss ) / batch_size_f
202
194
gradients = tape .gradient (loss , model .trainable_variables )
203
195
optimizer .apply_gradients (zip (gradients , model .trainable_variables ))
204
196
return loss
@@ -246,6 +238,9 @@ def draw_cam(self, x, label, window_size_h=512, alpha=0.6):
246
238
cv2 .imshow ('cam' , image_grid )
247
239
cv2 .waitKey (1 )
248
240
241
+ def print_loss (self , progress_str , loss ):
242
+ print (f'\r { progress_str } loss => { loss :.4f} ' , end = '' )
243
+
249
244
def train (self ):
250
245
if self .pretrained_iteration_count >= self .iterations :
251
246
print (f'pretrained iteration count { self .pretrained_iteration_count } is greater or equal than target iterations { self .iterations } ' )
@@ -263,7 +258,7 @@ def train(self):
263
258
while True :
264
259
for idx , (batch_x , batch_y ) in enumerate (self .train_data_generator .flow ()):
265
260
lr_scheduler .update (optimizer , iteration_count )
266
- loss = self .compute_gradient (self .model , optimizer , batch_x , batch_y , loss_function , self . class_weights )
261
+ loss = self .compute_gradient (self .model , optimizer , batch_x , batch_y , loss_function )
267
262
if self .show_class_activation_map and iteration_count % 100 == 0 :
268
263
try_count = 0
269
264
while True :
@@ -281,7 +276,7 @@ def train(self):
281
276
self .live_loss_plot .update (loss )
282
277
iteration_count += 1
283
278
progress_str = eta_calculator .update (iteration_count )
284
- print ( f' \r { progress_str } loss => { loss :.4f } ' , end = '' )
279
+ self . print_loss ( progress_str , loss )
285
280
if iteration_count % 2000 == 0 :
286
281
self .save_last_model (self .model , iteration_count )
287
282
if iteration_count == self .iterations :
0 commit comments