Commit

Add tensorboard for args

Kolin Guo committed Mar 7, 2021
1 parent 41330a4 commit 0b1d1ba
Showing 2 changed files with 118 additions and 106 deletions.
.dockerignore: 2 changes (1 addition & 1 deletion)
@@ -5,4 +5,4 @@ data
 docker
 notebook
 src
-tf_logs
+tb_logs
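(This entry keeps the TensorBoard log directory out of the Docker build context; it tracks the tf_logs to tb_logs rename that the --log-dir change in train.py below introduces.)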
src/train.py: 222 changes (117 additions & 105 deletions)
@@ -11,6 +11,7 @@
 import numpy as np
 import matplotlib.pyplot as plt
 import torch
+from torch.utils.tensorboard import SummaryWriter
 
 from networks.dataset import load_dataset
 from networks.models import get_model
@@ -63,9 +64,9 @@ def get_parser() -> argparse.ArgumentParser:
     #    "--ckpt-filepath", type=str, default=None,
     #    help="Checkpoint filepath to load and resume training from "
     #         "e.g. ./cp-001-50.51.ckpt.index")
-    #train_parser.add_argument(
-    #    "--log-dir", type=str, default='/BrainSeg/tf_logs',
-    #    help="Directory for saving tensorboard logs")
+    train_parser.add_argument(
+        "--log-dir", type=str, default='/Colorization/tb_logs',
+        help="Directory for saving tensorboard logs")
     #train_parser.add_argument(
     #    "--file-suffix", type=str,
     #    default=datetime.now().strftime("%Y%m%d_%H%M%S"),
@@ -81,6 +82,17 @@ def get_parser() -> argparse.ArgumentParser:
 
     return main_parser
 
+def log_configs(log_dir: str, args) -> None:
+    """Log configuration of this training session"""
+    #writer = tf.summary.create_file_writer(log_dir + "/config")
+    writer = SummaryWriter(log_dir + "/config")
+
+    for key, value in vars(args).items():
+        writer.add_text(str(key), str(value), global_step=0)  # torch kwarg is global_step, not TF-style step
+
+    writer.flush()
+    writer.close()  # FIXME: cautious; close() was unused in the TF version
+
 def train(args) -> None:
     """Start training based on args input"""
     # Check if GPU is available
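A minimal usage sketch of the new log_configs helper (not part of the diff): it writes one TensorBoard text summary per parsed argument, so the whole training configuration shows up in TensorBoard's Text tab. The Namespace values below are hypothetical; only SummaryWriter and its add_text/flush/close calls come from the code above.

import argparse
from torch.utils.tensorboard import SummaryWriter

def log_configs(log_dir: str, args) -> None:
    """Write one text summary per argument under <log_dir>/config."""
    writer = SummaryWriter(log_dir + "/config")
    for key, value in vars(args).items():
        # tag = argument name, body = its value, pinned at step 0
        writer.add_text(str(key), str(value), global_step=0)
    writer.flush()
    writer.close()

# Hypothetical arguments mirroring the parser defaults above
args = argparse.Namespace(log_dir="/tmp/tb_logs", num_epochs=10, fold_num=0)
log_configs(args.log_dir, args)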
@@ -115,112 +127,112 @@ def train(args) -> None:



-    class_names = ['Background', 'Gray Matter', 'White Matter']
-    model.compile(optimizer=optimizers.Adam(),
-                  loss=get_loss_func(args.loss_func, class_weight,
-                                     gamma=args.focal_loss_gamma),
-                  metrics=[metrics.SparseCategoricalAccuracy(),
-                           SparseMeanIoU(num_classes=3, name='IoU/Mean'),
-                           SparsePixelAccuracy(num_classes=3, name='PixelAcc'),
-                           SparseMeanAccuracy(num_classes=3, name='MeanAcc'),
-                           SparseFreqIoU(num_classes=3, name='IoU/Freq_weighted'),
-                           SparseConfusionMatrix(num_classes=3, name='cm')] \
-                  + SparseIoU.get_iou_metrics(num_classes=3, class_names=class_names))
-
-    # Create another checkpoint/log folder for model.name and timestamp
-    args.ckpt_dir = os.path.join(args.ckpt_dir,
-                                 model.name+'-'+args.file_suffix)
-    args.log_dir = os.path.join(args.log_dir, 'fit',
-                                model.name+'-'+args.file_suffix)
-    if args.fold_num != 0:    # If using five-fold cross-validation
-        args.ckpt_dir += f'_fold_{args.fold_num}'
-        args.log_dir += f'_fold_{args.fold_num}'
-
-    # Check if resume from training
-    initial_epoch = 0
-    if args.ckpt_filepath is not None:
-        if args.ckpt_weights_only:
-            if args.ckpt_filepath.endswith('.index'):  # Get rid of the suffix
-                args.ckpt_filepath = args.ckpt_filepath.replace('.index', '')
-            model.load_weights(args.ckpt_filepath).assert_existing_objects_matched()
-            print('Model weights loaded')
-        else:
-            model = load_whole_model(args.ckpt_filepath)
-            print('Whole model (weights + optimizer state) loaded')
-
-        initial_epoch = int(args.ckpt_filepath.split('/')[-1]\
-                            .split('-')[1])
-        # Save in same checkpoint_dir but different log_dir (add current time)
-        args.ckpt_dir = os.path.abspath(
-            os.path.dirname(args.ckpt_filepath))
-        args.log_dir = args.ckpt_dir.replace(
-            'checkpoints', 'tf_logs/fit') + f'-retrain_{args.file_suffix}'
+    #class_names = ['Background', 'Gray Matter', 'White Matter']
+    #model.compile(optimizer=optimizers.Adam(),
+    #              loss=get_loss_func(args.loss_func, class_weight,
+    #                                 gamma=args.focal_loss_gamma),
+    #              metrics=[metrics.SparseCategoricalAccuracy(),
+    #                       SparseMeanIoU(num_classes=3, name='IoU/Mean'),
+    #                       SparsePixelAccuracy(num_classes=3, name='PixelAcc'),
+    #                       SparseMeanAccuracy(num_classes=3, name='MeanAcc'),
+    #                       SparseFreqIoU(num_classes=3, name='IoU/Freq_weighted'),
+    #                       SparseConfusionMatrix(num_classes=3, name='cm')] \
+    #              + SparseIoU.get_iou_metrics(num_classes=3, class_names=class_names))
+
+    ## Create another checkpoint/log folder for model.name and timestamp
+    #args.ckpt_dir = os.path.join(args.ckpt_dir,
+    #                             model.name+'-'+args.file_suffix)
+    #args.log_dir = os.path.join(args.log_dir, 'fit',
+    #                            model.name+'-'+args.file_suffix)
+    #if args.fold_num != 0:    # If using five-fold cross-validation
+    #    args.ckpt_dir += f'_fold_{args.fold_num}'
+    #    args.log_dir += f'_fold_{args.fold_num}'
+
+    ## Check if resume from training
+    #initial_epoch = 0
+    #if args.ckpt_filepath is not None:
+    #    if args.ckpt_weights_only:
+    #        if args.ckpt_filepath.endswith('.index'):  # Get rid of the suffix
+    #            args.ckpt_filepath = args.ckpt_filepath.replace('.index', '')
+    #        model.load_weights(args.ckpt_filepath).assert_existing_objects_matched()
+    #        print('Model weights loaded')
+    #    else:
+    #        model = load_whole_model(args.ckpt_filepath)
+    #        print('Whole model (weights + optimizer state) loaded')
+
+    #    initial_epoch = int(args.ckpt_filepath.split('/')[-1]\
+    #                        .split('-')[1])
+    #    # Save in same checkpoint_dir but different log_dir (add current time)
+    #    args.ckpt_dir = os.path.abspath(
+    #        os.path.dirname(args.ckpt_filepath))
+    #    args.log_dir = args.ckpt_dir.replace(
+    #        'checkpoints', 'tf_logs/fit') + f'-retrain_{args.file_suffix}'

     # Write configurations to log_dir
-    log_configs(args.log_dir, save_svs_file, train_dataset, val_dataset, args)
+    log_configs(args.log_dir, args)

-    # Create checkpoint directory
-    if not os.path.exists(args.ckpt_dir):
-        os.makedirs(args.ckpt_dir)
-    # Create log directory
-    if not os.path.exists(args.log_dir):
-        os.makedirs(args.log_dir)
-
-    # Create a callback that saves the model's weights every 1 epoch
-    if val_dataset:
-        ckpt_path = os.path.join(
-            args.ckpt_dir, 'cp-{epoch:03d}-{val_IoU/Mean:.4f}.ckpt')
-    else:
-        ckpt_path = os.path.join(
-            args.ckpt_dir, 'cp-{epoch:03d}-{IoU/Mean:.4f}.ckpt')
-    cp_callback = callbacks.ModelCheckpoint(
-        filepath=ckpt_path,
-        verbose=1,
-        save_weights_only=args.ckpt_weights_only,
-        save_freq='epoch')
-
-    # Create a TensorBoard callback
-    tb_callback = callbacks.TensorBoard(
-        log_dir=args.log_dir,
-        histogram_freq=1,
-        write_graph=True,
-        write_images=False,
-        update_freq='batch',
-        profile_batch='100, 120')
-
-    # Create a Lambda callback for plotting confusion matrix
-    cm_callback = get_cm_callback(args.log_dir, class_names)
-
-    # Create a TerminateOnNaN callback
-    nan_callback = callbacks.TerminateOnNaN()
-
-    # Create an EarlyStopping callback
-    if val_dataset:
-        es_callback = callbacks.EarlyStopping(monitor='val_IoU/Mean',
-                                              min_delta=0.01,
-                                              patience=3,
-                                              verbose=1,
-                                              mode='max')
-
-    if val_dataset:
-        model.fit(
-            train_dataset,
-            epochs=args.num_epochs,
-            steps_per_epoch=len(train_dataset) \
-                if args.steps_per_epoch == -1 else args.steps_per_epoch,
-            initial_epoch=initial_epoch,
-            validation_data=val_dataset,
-            validation_steps=len(val_dataset) // args.val_subsplits \
-                if args.val_steps == -1 else args.val_steps,
-            callbacks=[cp_callback, tb_callback, nan_callback, cm_callback, es_callback])
-    else:
-        model.fit(
-            train_dataset,
-            epochs=args.num_epochs,
-            steps_per_epoch=len(train_dataset) \
-                if args.steps_per_epoch == -1 else args.steps_per_epoch,
-            initial_epoch=initial_epoch,
-            callbacks=[cp_callback, tb_callback, nan_callback, cm_callback])
+    #if not os.path.exists(args.ckpt_dir):
+    #    os.makedirs(args.ckpt_dir)
+    ## Create log directory
+    #if not os.path.exists(args.log_dir):
+    #    os.makedirs(args.log_dir)
+
+    ## Create a callback that saves the model's weights every 1 epoch
+    #if val_dataset:
+    #    ckpt_path = os.path.join(
+    #        args.ckpt_dir, 'cp-{epoch:03d}-{val_IoU/Mean:.4f}.ckpt')
+    #else:
+    #    ckpt_path = os.path.join(
+    #        args.ckpt_dir, 'cp-{epoch:03d}-{IoU/Mean:.4f}.ckpt')
+    #cp_callback = callbacks.ModelCheckpoint(
+    #    filepath=ckpt_path,
+    #    verbose=1,
+    #    save_weights_only=args.ckpt_weights_only,
+    #    save_freq='epoch')
+
+    ## Create a TensorBoard callback
+    #tb_callback = callbacks.TensorBoard(
+    #    log_dir=args.log_dir,
+    #    histogram_freq=1,
+    #    write_graph=True,
+    #    write_images=False,
+    #    update_freq='batch',
+    #    profile_batch='100, 120')
+
+    ## Create a Lambda callback for plotting confusion matrix
+    #cm_callback = get_cm_callback(args.log_dir, class_names)
+
+    ## Create a TerminateOnNaN callback
+    #nan_callback = callbacks.TerminateOnNaN()
+
+    ## Create an EarlyStopping callback
+    #if val_dataset:
+    #    es_callback = callbacks.EarlyStopping(monitor='val_IoU/Mean',
+    #                                          min_delta=0.01,
+    #                                          patience=3,
+    #                                          verbose=1,
+    #                                          mode='max')
+
+    #if val_dataset:
+    #    model.fit(
+    #        train_dataset,
+    #        epochs=args.num_epochs,
+    #        steps_per_epoch=len(train_dataset) \
+    #            if args.steps_per_epoch == -1 else args.steps_per_epoch,
+    #        initial_epoch=initial_epoch,
+    #        validation_data=val_dataset,
+    #        validation_steps=len(val_dataset) // args.val_subsplits \
+    #            if args.val_steps == -1 else args.val_steps,
+    #        callbacks=[cp_callback, tb_callback, nan_callback, cm_callback, es_callback])
+    #else:
+    #    model.fit(
+    #        train_dataset,
+    #        epochs=args.num_epochs,
+    #        steps_per_epoch=len(train_dataset) \
+    #            if args.steps_per_epoch == -1 else args.steps_per_epoch,
+    #        initial_epoch=initial_epoch,
+    #        callbacks=[cp_callback, tb_callback, nan_callback, cm_callback])
     # TODO: Switch to tf.data
 
     print('Training finished!')
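Once train.py has run, the logged configuration can be checked without opening the UI; a sketch using TensorBoard's EventAccumulator, assuming the default --log-dir of /Colorization/tb_logs (or browse it with: tensorboard --logdir /Colorization/tb_logs, under the Text tab):

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Point at the "config" run that log_configs creates under the log dir
ea = EventAccumulator("/Colorization/tb_logs/config")
ea.Reload()        # parse the event files on disk
print(ea.Tags())   # the text summaries appear here, one tag per argument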