Skip to content

Commit 250eb17

Browse files
committed
Use python logging in training.
This way, we get the training logs in the experiment_root too! TODO: Maybe also do that in embed and eval?
1 parent 0e30b89 commit 250eb17

File tree

2 files changed

+225
-17
lines changed

2 files changed

+225
-17
lines changed

common.py

+199
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
""" A bunch of general utilities shared by train/embed/eval """
22

33
from argparse import ArgumentTypeError
4+
import logging
45
import os
56

67
import numpy as np
@@ -154,3 +155,201 @@ def fid_to_image(fid, pid, image_root, image_size):
154155
image_resized = tf.image.resize_images(image_decoded, image_size)
155156

156157
return image_resized, fid, pid
158+
159+
160+
def get_logging_dict(name):
161+
return {
162+
'version': 1,
163+
'disable_existing_loggers': False,
164+
'formatters': {
165+
'standard': {
166+
'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
167+
},
168+
},
169+
'handlers': {
170+
'stderr': {
171+
'level': 'INFO',
172+
'formatter': 'standard',
173+
'class': 'common.ColorStreamHandler',
174+
'stream': 'ext://sys.stderr',
175+
},
176+
'logfile': {
177+
'level': 'DEBUG',
178+
'formatter': 'standard',
179+
'class': 'logging.FileHandler',
180+
'filename': name + '.log',
181+
'mode': 'a',
182+
}
183+
},
184+
'loggers': {
185+
'': {
186+
'handlers': ['stderr', 'logfile'],
187+
'level': 'DEBUG',
188+
'propagate': True
189+
},
190+
191+
# extra ones to shut up.
192+
'tensorflow': {
193+
'handlers': ['stderr', 'logfile'],
194+
'level': 'INFO',
195+
},
196+
}
197+
}
198+
199+
200+
# Source for the remainder: https://gist.github.com/mooware/a1ed40987b6cc9ab9c65
201+
# Fixed some things mentioned in the comments there.
202+
203+
# colored stream handler for python logging framework (use the ColorStreamHandler class).
204+
#
205+
# based on:
206+
# http://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output/1336640#1336640
207+
208+
# how to use:
209+
# i used a dict-based logging configuration, not sure what else would work.
210+
#
211+
# import logging, logging.config, colorstreamhandler
212+
#
213+
# _LOGCONFIG = {
214+
# "version": 1,
215+
# "disable_existing_loggers": False,
216+
#
217+
# "handlers": {
218+
# "console": {
219+
# "class": "colorstreamhandler.ColorStreamHandler",
220+
# "stream": "ext://sys.stderr",
221+
# "level": "INFO"
222+
# }
223+
# },
224+
#
225+
# "root": {
226+
# "level": "INFO",
227+
# "handlers": ["console"]
228+
# }
229+
# }
230+
#
231+
# logging.config.dictConfig(_LOGCONFIG)
232+
# mylogger = logging.getLogger("mylogger")
233+
# mylogger.warning("foobar")
234+
235+
# Copyright (c) 2014 Markus Pointner
236+
#
237+
# Permission is hereby granted, free of charge, to any person obtaining a copy
238+
# of this software and associated documentation files (the "Software"), to deal
239+
# in the Software without restriction, including without limitation the rights
240+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
241+
# copies of the Software, and to permit persons to whom the Software is
242+
# furnished to do so, subject to the following conditions:
243+
#
244+
# The above copyright notice and this permission notice shall be included in
245+
# all copies or substantial portions of the Software.
246+
#
247+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
248+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
249+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
250+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
251+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
252+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
253+
# THE SOFTWARE.
254+
255+
class _AnsiColorStreamHandler(logging.StreamHandler):
256+
DEFAULT = '\x1b[0m'
257+
RED = '\x1b[31m'
258+
GREEN = '\x1b[32m'
259+
YELLOW = '\x1b[33m'
260+
CYAN = '\x1b[36m'
261+
262+
CRITICAL = RED
263+
ERROR = RED
264+
WARNING = YELLOW
265+
INFO = DEFAULT # GREEN
266+
DEBUG = CYAN
267+
268+
@classmethod
269+
def _get_color(cls, level):
270+
if level >= logging.CRITICAL: return cls.CRITICAL
271+
elif level >= logging.ERROR: return cls.ERROR
272+
elif level >= logging.WARNING: return cls.WARNING
273+
elif level >= logging.INFO: return cls.INFO
274+
elif level >= logging.DEBUG: return cls.DEBUG
275+
else: return cls.DEFAULT
276+
277+
def __init__(self, stream=None):
278+
logging.StreamHandler.__init__(self, stream)
279+
280+
def format(self, record):
281+
text = logging.StreamHandler.format(self, record)
282+
color = self._get_color(record.levelno)
283+
return (color + text + self.DEFAULT) if self.is_tty() else text
284+
285+
def is_tty(self):
286+
isatty = getattr(self.stream, 'isatty', None)
287+
return isatty and isatty()
288+
289+
290+
class _WinColorStreamHandler(logging.StreamHandler):
291+
# wincon.h
292+
FOREGROUND_BLACK = 0x0000
293+
FOREGROUND_BLUE = 0x0001
294+
FOREGROUND_GREEN = 0x0002
295+
FOREGROUND_CYAN = 0x0003
296+
FOREGROUND_RED = 0x0004
297+
FOREGROUND_MAGENTA = 0x0005
298+
FOREGROUND_YELLOW = 0x0006
299+
FOREGROUND_GREY = 0x0007
300+
FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.
301+
FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
302+
303+
BACKGROUND_BLACK = 0x0000
304+
BACKGROUND_BLUE = 0x0010
305+
BACKGROUND_GREEN = 0x0020
306+
BACKGROUND_CYAN = 0x0030
307+
BACKGROUND_RED = 0x0040
308+
BACKGROUND_MAGENTA = 0x0050
309+
BACKGROUND_YELLOW = 0x0060
310+
BACKGROUND_GREY = 0x0070
311+
BACKGROUND_INTENSITY = 0x0080 # background color is intensified.
312+
313+
DEFAULT = FOREGROUND_WHITE
314+
CRITICAL = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
315+
ERROR = FOREGROUND_RED | FOREGROUND_INTENSITY
316+
WARNING = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
317+
INFO = FOREGROUND_GREEN
318+
DEBUG = FOREGROUND_CYAN
319+
320+
@classmethod
321+
def _get_color(cls, level):
322+
if level >= logging.CRITICAL: return cls.CRITICAL
323+
elif level >= logging.ERROR: return cls.ERROR
324+
elif level >= logging.WARNING: return cls.WARNING
325+
elif level >= logging.INFO: return cls.INFO
326+
elif level >= logging.DEBUG: return cls.DEBUG
327+
else: return cls.DEFAULT
328+
329+
def _set_color(self, code):
330+
import ctypes
331+
ctypes.windll.kernel32.SetConsoleTextAttribute(self._outhdl, code)
332+
333+
def __init__(self, stream=None):
334+
logging.StreamHandler.__init__(self, stream)
335+
# get file handle for the stream
336+
import ctypes, ctypes.util
337+
# for some reason find_msvcrt() sometimes doesn't find msvcrt.dll on my system?
338+
crtname = ctypes.util.find_msvcrt()
339+
if not crtname:
340+
crtname = ctypes.util.find_library("msvcrt")
341+
crtlib = ctypes.cdll.LoadLibrary(crtname)
342+
self._outhdl = crtlib._get_osfhandle(self.stream.fileno())
343+
344+
def emit(self, record):
345+
color = self._get_color(record.levelno)
346+
self._set_color(color)
347+
logging.StreamHandler.emit(self, record)
348+
self._set_color(self.FOREGROUND_WHITE)
349+
350+
# select ColorStreamHandler based on platform
351+
import platform
352+
if platform.system() == 'Windows':
353+
ColorStreamHandler = _WinColorStreamHandler
354+
else:
355+
ColorStreamHandler = _AnsiColorStreamHandler

train.py

+26-17
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from argparse import ArgumentParser
33
from datetime import timedelta
44
from importlib import import_module
5+
import logging.config
56
import os
67
from signal import SIGINT, SIGTERM
78
import sys
@@ -195,9 +196,9 @@ def main():
195196
# If the experiment directory exists already, we bail in fear.
196197
if os.path.exists(args.experiment_root):
197198
if os.listdir(args.experiment_root):
198-
print('The directory {} already exists and is not empty. If '
199-
'you want to resume training, append --resume to your '
200-
'call.'.format(args.experiment_root))
199+
print('The directory {} already exists and is not empty.'
200+
' If you want to resume training, append --resume to'
201+
' your call.'.format(args.experiment_root))
201202
exit(1)
202203
else:
203204
os.makedirs(args.experiment_root)
@@ -207,19 +208,23 @@ def main():
207208
with open(args_file, 'w') as f:
208209
json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True)
209210

211+
log_file = os.path.join(args.experiment_root, "train")
212+
logging.config.dictConfig(common.get_logging_dict(log_file))
213+
log = logging.getLogger('train')
214+
210215
# Also show all parameter values at the start, for ease of reading logs.
211-
print('Training using the following parameters:')
216+
log.info('Training using the following parameters:')
212217
for key, value in sorted(vars(args).items()):
213-
print('{}: {}'.format(key, value))
218+
log.info('{}: {}'.format(key, value))
214219

215220
# Check them here, so they are not required when --resume-ing.
216221
if not args.train_set:
217222
parser.print_help()
218-
print("You did not specify the `train_set` argument!")
223+
log.error("You did not specify the `train_set` argument!")
219224
sys.exit(1)
220225
if not args.image_root:
221226
parser.print_help()
222-
print("You did not specify the required `image_root` argument!")
227+
log.error("You did not specify the required `image_root` argument!")
223228
sys.exit(1)
224229

225230
# Load the data from the CSV file.
@@ -351,7 +356,7 @@ def main():
351356
if args.resume:
352357
# In case we're resuming, simply load the full checkpoint to init.
353358
last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
354-
print('Restoring from checkpoint: {}'.format(last_checkpoint))
359+
log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
355360
checkpoint_saver.restore(sess, last_checkpoint)
356361
else:
357362
# But if we're starting from scratch, we may need to load some
@@ -370,7 +375,7 @@ def main():
370375
summary_writer = tf.summary.FileWriter(args.experiment_root, sess.graph)
371376

372377
start_step = sess.run(global_step)
373-
print('Starting training from iteration {}.'.format(start_step))
378+
log.info('Starting training from iteration {}.'.format(start_step))
374379

375380
# Finally, here comes the main-loop. This `Uninterrupt` is a handy
376381
# utility such that an iteration still finishes on Ctrl+C and we can
@@ -397,13 +402,17 @@ def main():
397402

398403
# Do a huge print out of the current progress.
399404
seconds_todo = (args.train_iterations - step) * elapsed_time
400-
print('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
401-
'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
402-
step, float(np.min(b_loss)), float(np.mean(b_loss)),
403-
float(np.max(b_loss)),
404-
args.batch_k-1, float(b_prec_at_k),
405-
timedelta(seconds=int(seconds_todo)), elapsed_time),
406-
flush=True)
405+
log.info('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
406+
'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
407+
step,
408+
float(np.min(b_loss)),
409+
float(np.mean(b_loss)),
410+
float(np.max(b_loss)),
411+
args.batch_k-1, float(b_prec_at_k),
412+
timedelta(seconds=int(seconds_todo)),
413+
elapsed_time))
414+
sys.stdout.flush()
415+
sys.stderr.flush()
407416

408417
# Save a checkpoint of training every so often.
409418
if (args.checkpoint_frequency > 0 and
@@ -413,7 +422,7 @@ def main():
413422

414423
# Stop the main-loop at the end of the step, if requested.
415424
if u.interrupted:
416-
print("Interrupted on request!")
425+
log.info("Interrupted on request!")
417426
break
418427

419428
# Store one final checkpoint. This might be redundant, but it is crucial

0 commit comments

Comments
 (0)