Skip to content

Commit 56581e6

Browse files
committed
remove auto_balance parameter, split loss print to function
1 parent e4547a8 commit 56581e6

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

sigmoid_classifier.py

+11-16
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def __init__(self,
5656
aug_h_flip,
5757
lr_policy='step',
5858
model_name='model',
59-
auto_balance=False,
6059
live_loss_plot=False,
6160
checkpoint_interval=0,
6261
show_class_activation_map=False,
@@ -73,7 +72,6 @@ def __init__(self,
7372
self.batch_size = batch_size
7473
self.iterations = iterations
7574
self.lr_policy = lr_policy
76-
self.auto_balance = auto_balance
7775
self.live_loss_plot_flag = live_loss_plot
7876
self.max_val_acc = 0.0
7977
self.show_class_activation_map = show_class_activation_map
@@ -87,8 +85,8 @@ def __init__(self,
8785
train_image_path = self.unify_path(train_image_path)
8886
validation_image_path = self.unify_path(validation_image_path)
8987

90-
self.train_image_paths, train_class_names, self.class_weights, _ = self.init_image_paths(train_image_path)
91-
self.validation_image_paths, validation_class_names, _, self.include_unknown = self.init_image_paths(validation_image_path)
88+
self.train_image_paths, train_class_names, _ = self.init_image_paths(train_image_path)
89+
self.validation_image_paths, validation_class_names, self.include_unknown = self.init_image_paths(validation_image_path)
9290
if len(self.train_image_paths) == 0:
9391
print(f'no images in train_image_path : {train_image_path}')
9492
exit(0)
@@ -184,21 +182,15 @@ def init_image_paths(self, image_path):
184182
print()
185183
class_names = sorted(list(class_name_set))
186184
total_data_count = float(sum(class_counts)) + unknown_class_count
187-
class_weights = [1.0 - (count / total_data_count) if self.auto_balance else 0.0 for count in class_counts]
188-
return image_paths, class_names, class_weights, include_unknown
185+
return image_paths, class_names, include_unknown
189186

190187
@tf.function
191-
def compute_gradient(self, model, optimizer, batch_x, y_true, loss_function, class_weights):
188+
def compute_gradient(self, model, optimizer, batch_x, y_true, loss_function):
192189
with tf.GradientTape() as tape:
193190
y_pred = self.model(batch_x, training=True)
194191
loss = loss_function(y_true, y_pred)
195-
batch_size = tf.cast(tf.shape(y_true)[0], dtype=tf.int32)
196-
class_weights_t = tf.convert_to_tensor(class_weights, dtype=y_pred.dtype)
197-
if tf.reduce_sum(class_weights_t) > 0.0:
198-
class_weights_t = tf.repeat(tf.expand_dims(class_weights_t, axis=0), repeats=batch_size, axis=0)
199-
class_weights_t = tf.where(y_true == 1.0, class_weights_t, 1.0 - class_weights_t)
200-
loss *= class_weights_t
201-
loss = tf.reduce_sum(loss) / tf.cast(batch_size, dtype=loss.dtype)
192+
batch_size_f = tf.cast(tf.shape(y_true)[0], dtype=loss.dtype)
193+
loss = tf.reduce_sum(loss) / batch_size_f
202194
gradients = tape.gradient(loss, model.trainable_variables)
203195
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
204196
return loss
@@ -246,6 +238,9 @@ def draw_cam(self, x, label, window_size_h=512, alpha=0.6):
246238
cv2.imshow('cam', image_grid)
247239
cv2.waitKey(1)
248240

241+
def print_loss(self, progress_str, loss):
242+
print(f'\r{progress_str} loss => {loss:.4f}', end='')
243+
249244
def train(self):
250245
if self.pretrained_iteration_count >= self.iterations:
251246
print(f'pretrained iteration count {self.pretrained_iteration_count} is greater or equal than target iterations {self.iterations}')
@@ -263,7 +258,7 @@ def train(self):
263258
while True:
264259
for idx, (batch_x, batch_y) in enumerate(self.train_data_generator.flow()):
265260
lr_scheduler.update(optimizer, iteration_count)
266-
loss = self.compute_gradient(self.model, optimizer, batch_x, batch_y, loss_function, self.class_weights)
261+
loss = self.compute_gradient(self.model, optimizer, batch_x, batch_y, loss_function)
267262
if self.show_class_activation_map and iteration_count % 100 == 0:
268263
try_count = 0
269264
while True:
@@ -281,7 +276,7 @@ def train(self):
281276
self.live_loss_plot.update(loss)
282277
iteration_count += 1
283278
progress_str = eta_calculator.update(iteration_count)
284-
print(f'\r{progress_str} loss => {loss:.4f}', end='')
279+
self.print_loss(progress_str, loss)
285280
if iteration_count % 2000 == 0:
286281
self.save_last_model(self.model, iteration_count)
287282
if iteration_count == self.iterations:

0 commit comments

Comments
 (0)