Skip to content

Commit 829b9b0

Browse files
committed
change ale loss to ace, update default hyperparameters, use max iterations if zero checkpoint_interval
1 parent 3b8b0d5 commit 829b9b0

File tree

3 files changed

+24
-24
lines changed

3 files changed

+24
-24
lines changed

ale.py renamed to ace.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
Authors : inzapp
33
4-
Github url : https://github.com/inzapp/absolute-logarithmic-error
4+
Github url : https://github.com/inzapp/adaptive-crossentropy
55
6-
Copyright 2022 inzapp Authors. All Rights Reserved.
6+
Copyright 2023 inzapp Authors. All Rights Reserved.
77
88
Licensed under the Apache License, Version 2.0 (the "License"),
99
you may not use this file except in compliance with the License.
@@ -20,25 +20,23 @@
2020
import tensorflow as tf
2121

2222

23-
class AbsoluteLogarithmicError(tf.keras.losses.Loss):
24-
"""Computes the cross-entropy log scale loss between true labels and predicted labels.
23+
class AdaptiveCrossentropy(tf.keras.losses.Loss):
24+
"""Computes the adaptive cross-entropy loss between true labels and predicted labels.
2525
26-
This loss function can be used regardless of classification problem or regression problem.
27-
28-
See: https://github.com/inzapp/absolute-logarithmic-error
26+
See: https://github.com/inzapp/adaptive-crossentropy
2927
3028
Standalone usage:
3129
>>> y_true = [[0, 1], [0, 0]]
3230
>>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
33-
>>> ale = AbsoluteLogarithmicError()
34-
>>> loss = ale(y_true, y_pred)
31+
>>> ace = AdaptiveCrossentropy()
32+
>>> loss = ace(y_true, y_pred)
3533
>>> loss.numpy()
36-
array([[0.9162905, 0.9162905], [0.5108255, 0.9162905]], dtype=float32)
34+
array([[0.9162906, 0.9162905], [0.5108254, 0.9162906]], dtype=float32)
3735
3836
Usage:
39-
model.compile(optimizer='sgd', loss=AbsoluteLogarithmicError())
37+
model.compile(optimizer='sgd', loss=AdaptiveCrossentropy())
4038
"""
41-
def __init__(self, alpha=0.0, gamma=0.0, label_smoothing=0.0, reduce='none', name='AbsoluteLogarithmicError'):
39+
def __init__(self, alpha=0.0, gamma=0.0, label_smoothing=0.0, reduce='none', name='AdaptiveCrossentropy'):
4240
"""
4341
Args:
4442
alpha: Weight applied to the loss at positions where y_true is not positive (y_true != 1.0); positive positions are weighted by 1.0 - alpha.
@@ -79,14 +77,14 @@ def call(self, y_true, y_pred):
7977
eps = tf.cast(self.eps, y_pred.dtype)
8078
y_true_clip = tf.clip_by_value(y_true, self.label_smoothing, 1.0 - self.label_smoothing)
8179
y_pred_clip = tf.clip_by_value(y_pred, eps, 1.0 - eps)
82-
abs_error = tf.abs(y_true_clip - y_pred_clip)
83-
loss = -tf.math.log((1.0 + eps) - abs_error)
80+
loss = -((y_true * tf.math.log(y_pred + eps)) + ((1.0 - y_true) * tf.math.log(1.0 - y_pred + eps)))
8481
if self.alpha > 0.0:
8582
alpha = tf.ones_like(y_true) * self.alpha
8683
alpha = tf.where(y_true != 1.0, alpha, 1.0 - alpha)
8784
loss *= alpha
8885
if self.gamma >= 1.0:
89-
loss *= tf.pow(abs_error, self.gamma)
86+
adaptive_weight = tf.pow(tf.abs(y_true_clip - y_pred_clip), self.gamma)
87+
loss *= adaptive_weight
9088
if self.reduce == 'mean':
9189
loss = tf.reduce_mean(loss)
9290
elif self.reduce == 'sum':

sigmoid_classifier.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from live_plot import LivePlot
3535
from generator import DataGenerator
3636
from lr_scheduler import LRScheduler
37-
from ale import AbsoluteLogarithmicError
37+
from ace import AdaptiveCrossentropy
3838
from ckpt_manager import CheckpointManager
3939

4040

@@ -63,7 +63,7 @@ def __init__(self,
6363
cam_activation_layer_name='cam_activation',
6464
last_conv_layer_name='squeeze_conv'):
6565
super().__init__()
66-
assert checkpoint_interval >= 1000
66+
assert checkpoint_interval == 0 or checkpoint_interval >= 1000
6767
self.input_shape = input_shape
6868
self.lr = lr
6969
self.lrf = lrf
@@ -82,6 +82,8 @@ def __init__(self,
8282
self.pretrained_iteration_count = 0
8383
warnings.filterwarnings(action='ignore')
8484
self.set_model_name(model_name)
85+
if self.checkpoint_interval == 0:
86+
self.checkpoint_interval = self.iterations
8587

8688
train_image_path = self.unify_path(train_image_path)
8789
validation_image_path = self.unify_path(validation_image_path)
@@ -247,7 +249,7 @@ def train(self):
247249
print(f'\ntrain on {len(self.train_image_paths)} samples')
248250
print(f'validate on {len(self.validation_image_paths)} samples\n')
249251
optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr, beta_1=self.momentum)
250-
loss_function = AbsoluteLogarithmicError(alpha=self.alpha, gamma=self.gamma, label_smoothing=self.label_smoothing)
252+
loss_function = AdaptiveCrossentropy(alpha=self.alpha, gamma=self.gamma, label_smoothing=self.label_smoothing)
251253
lr_scheduler = LRScheduler(lr=self.lr, lrf=self.lrf, iterations=self.iterations, warm_up=self.warm_up, policy=self.lr_policy)
252254
self.init_checkpoint_dir()
253255
iteration_count = self.pretrained_iteration_count

train.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,18 @@
3030
input_shape=(64, 64, 1),
3131
lr=0.001,
3232
lrf=0.05,
33-
alpha=0.5,
33+
alpha=0.0,
3434
gamma=2.0,
35-
warm_up=0.5,
35+
warm_up=0.1,
3636
momentum=0.9,
37-
batch_size=32,
37+
batch_size=64,
3838
iterations=1000000,
3939
label_smoothing=0.1,
4040
aug_brightness=0.3,
41-
aug_contrast=0.4,
42-
aug_rotate=20,
41+
aug_contrast=0.3,
42+
aug_rotate=15,
4343
aug_h_flip=False,
44-
checkpoint_interval=20000,
44+
checkpoint_interval=10000,
4545
show_class_activation_map=False)
4646

4747
parser = argparse.ArgumentParser()

0 commit comments

Comments
 (0)