Skip to content

Commit d65037c

Browse files
committed
apply eta calculator, exit training if pretrained iteration count is greater equal then target iterations
1 parent 65ed6e6 commit d65037c

File tree

2 files changed

+116
-2
lines changed

2 files changed

+116
-2
lines changed

eta.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""
2+
Authors : inzapp
3+
4+
Github url : https://github.com/inzapp/eta-calculator
5+
6+
Copyright (c) 2023 Inzapp
7+
8+
Permission is hereby granted, free of charge, to any person obtaining
9+
a copy of this software and associated documentation files (the
10+
"Software"), to deal in the Software without restriction, including
11+
without limitation the rights to use, copy, modify, merge, publish,
12+
distribute, sublicense, and/or sell copies of the Software, and to
13+
permit persons to whom the Software is furnished to do so, subject to
14+
the following conditions:
15+
16+
The above copyright notice and this permission notice shall be
17+
included in all copies or substantial portions of the Software.
18+
19+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
20+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
21+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
22+
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
23+
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
24+
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
25+
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
26+
"""
27+
from time import perf_counter
28+
29+
30+
class ETACalculator:
31+
def __init__(self, iterations, start_iteration=0, buffer_size=100):
32+
self.iterations = iterations
33+
self.start_iteration = start_iteration
34+
self.buffer_size = buffer_size
35+
self.start_time = 0
36+
self.recent_times = []
37+
self.recent_iterations = []
38+
39+
def start(self):
40+
self.start_time = perf_counter()
41+
self.recent_iterations.append(self.start_iteration)
42+
self.recent_times.append(perf_counter())
43+
44+
def end(self):
45+
avg_ips = float(self.iterations - self.start_iteration) / (perf_counter() - self.start_time)
46+
elapsed_time = self.convert_to_time_str(int(perf_counter() - self.start_time))
47+
return avg_ips, elapsed_time
48+
49+
def reset(self):
50+
self.recent_times = []
51+
self.recent_iterations = []
52+
53+
def update_buffer(self, iteration_count):
54+
self.recent_times.append(perf_counter())
55+
self.recent_iterations.append(iteration_count)
56+
if len(self.recent_times) > self.buffer_size:
57+
self.recent_times.pop(0)
58+
if len(self.recent_iterations) > self.buffer_size:
59+
self.recent_iterations.pop(0)
60+
61+
def convert_to_time_str(self, total_sec):
62+
times = []
63+
hh = total_sec // 3600
64+
times.append(str(hh).rjust(2, '0'))
65+
total_sec %= 3600
66+
mm = total_sec // 60
67+
times.append(str(mm).rjust(2, '0'))
68+
total_sec %= 60
69+
ss = total_sec
70+
times.append(str(ss).rjust(2, '0'))
71+
return ':'.join(times)
72+
73+
def update(self, iteration_count, return_values=False):
74+
self.update_buffer(iteration_count)
75+
elapsed_sec = self.recent_times[-1] - self.recent_times[0]
76+
total_iterations = iteration_count - self.recent_iterations[0]
77+
ips = total_iterations / elapsed_sec
78+
eta = (self.iterations - iteration_count) / ips
79+
elapsed_time = perf_counter() - self.start_time
80+
per = int(iteration_count / float(self.iterations) * 1000.0) / 10.0
81+
eta_str = self.convert_to_time_str(int(eta))
82+
elapsed_time_str= self.convert_to_time_str(int(elapsed_time))
83+
progress_str = f'[Iteration: {iteration_count}/{self.iterations}({per:.1f}%), {ips:.2f}it/s, {elapsed_time_str}<{eta_str}]'
84+
if return_values:
85+
return eta, ips, elapsed_time, per, progress_str
86+
else:
87+
return progress_str
88+
89+
90+
if __name__ == '__main__':
91+
import shutil as sh
92+
from time import sleep
93+
total_iterations = 500
94+
iteration_count = 0
95+
eta_calculator = ETACalculator(iterations=total_iterations, start_iteration=iteration_count)
96+
eta_calculator.start()
97+
while True:
98+
sleep(0.01)
99+
iteration_count += 1
100+
progress_str = eta_calculator.update(iteration_count)
101+
print(progress_str)
102+
if iteration_count == total_iterations:
103+
break
104+
avg_ips, elapsed_time = eta_calculator.end()
105+
eta_calculator.reset()
106+
print(f'\ntotal {total_iterations} iterations end successfully with avg IPS {avg_ips:.1f}, elapsed time : {elapsed_time}')
107+

sigmoid_classifier.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from glob import glob
3030
from tqdm import tqdm
3131
from model import Model
32+
from eta import ETACalculator
3233
from live_plot import LivePlot
3334
from generator import DataGenerator
3435
from lr_scheduler import LRScheduler
@@ -246,6 +247,9 @@ def draw_cam(self, x, label, window_size_h=512, alpha=0.6):
246247
cv2.waitKey(1)
247248

248249
def train(self):
250+
if self.pretrained_iteration_count >= self.iterations:
251+
print(f'pretrained iteration count {self.pretrained_iteration_count} is greater or equal than target iterations {self.iterations}')
252+
exit(0)
249253
self.model.summary()
250254
print(f'\ntrain on {len(self.train_image_paths)} samples')
251255
print(f'validate on {len(self.validation_image_paths)} samples\n')
@@ -254,6 +258,8 @@ def train(self):
254258
lr_scheduler = LRScheduler(lr=self.lr, iterations=self.iterations, warm_up=self.warm_up, policy=self.lr_policy)
255259
self.init_checkpoint_dir()
256260
iteration_count = self.pretrained_iteration_count
261+
eta_calculator = ETACalculator(iterations=self.iterations, start_iteration=iteration_count)
262+
eta_calculator.start()
257263
while True:
258264
for idx, (batch_x, batch_y) in enumerate(self.train_data_generator.flow()):
259265
lr_scheduler.update(optimizer, iteration_count)
@@ -274,7 +280,8 @@ def train(self):
274280
if self.live_loss_plot_flag:
275281
self.live_loss_plot.update(loss)
276282
iteration_count += 1
277-
print(f'\r[iteration count : {iteration_count:6d}] loss => {loss:.4f}', end='')
283+
progress_str = eta_calculator.update(iteration_count)
284+
print(f'\r{progress_str} loss => {loss:.4f}', end='')
278285
if iteration_count % 2000 == 0:
279286
self.save_last_model(self.model, iteration_count)
280287
if iteration_count == self.iterations:
@@ -287,7 +294,7 @@ def train(self):
287294
self.save_model(iteration_count)
288295

289296
def save_model(self, iteration_count):
290-
print(f'iteration count : {iteration_count}')
297+
print()
291298
if self.validation_data_generator.flow() is None:
292299
self.save_last_model(self.model, iteration_count)
293300
else:

0 commit comments

Comments
 (0)