Skip to content

Commit e2b17bd

Browse files
committed
update lr scheduler
1 parent 2721614 commit e2b17bd

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

lr_scheduler.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def __init__(self,
3131
iterations,
3232
lr,
3333
policy,
34-
min_lr=0.0,
3534
warm_up=0.1,
3635
min_momentum=0.85,
3736
max_momentum=0.95,
@@ -40,15 +39,13 @@ def __init__(self,
4039
decay_step=0.1,
4140
double_step=True):
4241
assert 0.0 <= lr <= 1.0
43-
assert 0.0 <= min_lr <= 1.0
4442
assert 0.0 <= warm_up <= 1.0
4543
assert 0.0 <= min_momentum <= 1.0
4644
assert 0.0 <= max_momentum <= 1.0
4745
assert 0.0 <= decay_step <= 1.0
4846
assert policy in ['constant', 'step', 'cosine', 'onecycle']
4947
self.lr = lr
5048
self.policy = policy
51-
self.min_lr = min_lr
5249
self.max_lr = self.lr
5350
self.warm_up = warm_up
5451
self.min_momentum = min_momentum
@@ -57,6 +54,7 @@ def __init__(self,
5754
self.cycle_length = initial_cycle_length
5855
self.cycle_weight = cycle_weight
5956
self.decay_step = decay_step
57+
self.min_lr = self.lr * self.decay_step ** 2.0
6058
self.double_step = double_step
6159
self.cycle_step = 0
6260

@@ -92,7 +90,7 @@ def __schedule_step_decay(self, optimizer, iteration_count):
9290
if warm_up_iteration > 0 and iteration_count <= warm_up_iteration:
9391
lr = self.__warm_up_lr(iteration_count, warm_up_iteration)
9492
elif self.double_step and iteration_count >= int(self.iterations * 0.9):
95-
lr = self.lr * self.decay_step * self.decay_step
93+
lr = self.lr * self.decay_step ** 2.0
9694
elif iteration_count >= int(self.iterations * 0.8):
9795
lr = self.lr * self.decay_step
9896
else:
@@ -101,19 +99,19 @@ def __schedule_step_decay(self, optimizer, iteration_count):
10199
return lr
102100

103101
def __schedule_one_cycle(self, optimizer, iteration_count):
104-
warm_up = 0.3
105-
min_lr = self.min_lr
102+
min_lr = 0.0
106103
max_lr = self.max_lr
107104
min_mm = self.min_momentum
108105
max_mm = self.max_momentum
109-
warm_up_iterations = int(self.iterations * warm_up)
110-
if iteration_count <= warm_up_iterations:
106+
warm_up_iterations = int(self.iterations * self.warm_up)
107+
if warm_up_iterations > 0 and iteration_count <= warm_up_iterations:
111108
iterations = warm_up_iterations
112109
lr = ((np.cos(((iteration_count * np.pi) / iterations) + np.pi) + 1.0) * 0.5) * (max_lr - min_lr) + min_lr # increase only until target iterations
113110
mm = ((np.cos(((iteration_count * np.pi) / iterations) + 0.0) + 1.0) * 0.5) * (max_mm - min_mm) + min_mm # decrease only until target iterations
114111
self.__set_lr(optimizer, lr)
115112
self.__set_momentum(optimizer, mm)
116113
else:
114+
min_lr = self.min_lr
117115
iteration_count -= warm_up_iterations + 1
118116
iterations = self.iterations - warm_up_iterations
119117
lr = ((np.cos(((iteration_count * np.pi) / iterations) + 0.0) + 1.0) * 0.5) * (max_lr - min_lr) + min_lr # decrease only until target iterations
@@ -140,7 +138,7 @@ def plot_lr(policy):
140138
import tensorflow as tf
141139
from matplotlib import pyplot as plt
142140
lr = 0.001
143-
warm_up = 0.1
141+
warm_up = 0.3
144142
decay_step = 0.2
145143
iterations = 37500
146144
iterations = int(iterations / (1.0 - warm_up))

0 commit comments

Comments
 (0)