
Commit ddbcd54

mmarcinkiewicznv authored and kkudrynski committed
[UNet/TF2] Add tf-trt and SavedModel tests. Remove profiling tests.
1 parent 7ce1754 commit ddbcd54

6 files changed: +98 -13 lines

TensorFlow2/Segmentation/UNet_Medical/data_loading/data_loader.py

+10 -4

@@ -25,12 +25,13 @@
 class Dataset:
     """Load, separate and prepare the data for training and prediction"""
 
-    def __init__(self, data_dir, batch_size, fold, augment=False, gpu_id=0, num_gpus=1, seed=0):
+    def __init__(self, data_dir, batch_size, fold, augment=False, gpu_id=0, num_gpus=1, seed=0, amp=False):
         if not os.path.exists(data_dir):
             raise FileNotFoundError('Cannot find data dir: {}'.format(data_dir))
         self._data_dir = data_dir
         self._batch_size = batch_size
         self._augment = augment
+        self.precision = tf.float16 if amp else tf.float32
 
         self._seed = seed
 
@@ -149,7 +150,7 @@ def _preproc_samples(self, inputs, labels, augment=True):
         cond = tf.less(labels, 0.5 * tf.ones(tf.shape(input=labels)))
         labels = tf.where(cond, tf.zeros(tf.shape(input=labels)), tf.ones(tf.shape(input=labels)))
 
-        return inputs, labels
+        return tf.cast(inputs, self.precision), labels
 
     @tf.function
     def _preproc_eval_samples(self, inputs, labels):
@@ -162,7 +163,12 @@ def _preproc_eval_samples(self, inputs, labels):
         cond = tf.less(labels, 0.5 * tf.ones(tf.shape(input=labels)))
         labels = tf.where(cond, tf.zeros(tf.shape(input=labels)), tf.ones(tf.shape(input=labels)))
 
-        return (inputs, labels)
+        return tf.cast(inputs, self.precision), labels
+
+    @tf.function
+    def _preproc_test_samples(self, inputs):
+        inputs = self._normalize_inputs(inputs)
+        return tf.cast(inputs, self.precision)
 
     def train_fn(self, drop_remainder=False):
         """Input function for training"""
@@ -195,7 +201,7 @@ def test_fn(self, count, drop_remainder=False):
         dataset = tf.data.Dataset.from_tensor_slices(
             self._test_images)
         dataset = dataset.repeat(count=count)
-        dataset = dataset.map(self._normalize_inputs)
+        dataset = dataset.map(self._preproc_test_samples)
         dataset = dataset.batch(self._batch_size, drop_remainder=drop_remainder)
         dataset = dataset.prefetch(self._batch_size)
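The net effect is that every pipeline (train, eval, and now test) casts its image tensors to the precision chosen by the new amp flag. A minimal, self-contained sketch of the idea (not the repository's Dataset class; make_test_pipeline and the use of per_image_standardization in place of _normalize_inputs are assumptions for illustration):

import tensorflow as tf

def make_test_pipeline(images, batch_size, amp=False):
    # Pick the element dtype once from the AMP flag, as the commit does via self.precision.
    precision = tf.float16 if amp else tf.float32

    def _preproc(x):
        x = tf.image.per_image_standardization(x)  # hypothetical stand-in for _normalize_inputs
        return tf.cast(x, precision)

    ds = tf.data.Dataset.from_tensor_slices(images)
    ds = ds.map(_preproc)
    return ds.batch(batch_size)

images = tf.zeros([4, 572, 572, 1], dtype=tf.float32)
for batch in make_test_pipeline(images, batch_size=2, amp=True):
    print(batch.dtype)  # float16 when amp=True, float32 otherwise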

TensorFlow2/Segmentation/UNet_Medical/main.py

+2 -1

@@ -51,7 +51,8 @@ def main():
                       augment=params.augment,
                       gpu_id=hvd.rank(),
                       num_gpus=hvd.size(),
-                      seed=params.seed)
+                      seed=params.seed,
+                      amp=params.use_amp)
 
     if 'train' in params.exec_mode:
         train(params, model, dataset, logger)
TensorFlow2/Segmentation/UNet_Medical/model/tf_trt.py

+56 -0

@@ -0,0 +1,56 @@
+import os
+from operator import itemgetter
+
+import tensorflow as tf
+from tensorflow.python.compiler.tensorrt import trt_convert as trt
+from tensorflow.compat.v1.saved_model import tag_constants, signature_constants
+
+
+def export_model(model_dir, prec, tf_trt_model_dir=None):
+    model = tf.keras.models.load_model(os.path.join(model_dir, f'saved_model_{prec}'))
+    input_shape = [1, 572, 572, 1]
+    dummy_input = tf.constant(tf.zeros(input_shape, dtype=tf.float32 if prec=="fp32" else tf.float16))
+    _ = model(dummy_input, training=False)
+
+    trt_prec = trt.TrtPrecisionMode.FP32 if prec == "fp32" else trt.TrtPrecisionMode.FP16
+    converter = trt.TrtGraphConverterV2(
+        input_saved_model_dir=os.path.join(model_dir, f'saved_model_{prec}'),
+        conversion_params=trt.TrtConversionParams(precision_mode=trt_prec),
+    )
+    converter.convert()
+    tf_trt_model_dir = tf_trt_model_dir or f'/tmp/tf-trt_model_{prec}'
+    converter.save(tf_trt_model_dir)
+    print(f"TF-TRT model saved at {tf_trt_model_dir}")
+
+
+def _force_gpu_resync(func):
+    p = tf.constant(0.)  # Create small tensor to force GPU resync
+
+    def wrapper(*args, **kwargs):
+        rslt = func(*args, **kwargs)
+        (p + 1.).numpy()  # Sync the GPU
+        return rslt
+
+    return wrapper
+
+
+class TFTRTModel:
+    def __init__(self, model_dir, precision, output_tensor_name="output_1"):
+        temp_tftrt_dir = f"/tmp/tf-trt_model_{precision}"
+        export_model(model_dir, precision, temp_tftrt_dir)
+        saved_model_loaded = tf.saved_model.load(temp_tftrt_dir, tags=[tag_constants.SERVING])
+        print(f"TF-TRT model loaded from {temp_tftrt_dir}")
+        self.graph_func = saved_model_loaded.signatures[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
+        self.output_tensor_name = output_tensor_name
+        self.precision = tf.float16 if precision == "amp" else tf.float32
+
+    def __call__(self, x, **kwargs):
+        return self.infer_step(x)
+
+    #@_force_gpu_resync
+    @tf.function(jit_compile=False)
+    def infer_step(self, batch_x):
+        if batch_x.dtype != self.precision:
+            batch_x = tf.cast(batch_x, self.precision)
+        output = self.graph_func(batch_x)
+        return itemgetter(self.output_tensor_name)(output)
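A hedged usage sketch of the new module (the directory layout and batch shape are assumptions; export_model expects a saved_model_{prec} directory produced by training with --use_savedmodel, and TF-TRT requires a TensorRT-enabled TensorFlow build):

import numpy as np
import tensorflow as tf
from model.tf_trt import export_model, TFTRTModel

model_dir = '/results'  # hypothetical results directory containing saved_model_fp32/
export_model(model_dir, 'fp32', '/results/tf-trt_model_fp32')  # one-off conversion to TF-TRT

# Wrap the converted graph for inference; __call__ ignores extra kwargs such as training=False.
trt_model = TFTRTModel(model_dir=model_dir, precision='fp32')
batch = tf.constant(np.zeros([1, 572, 572, 1], dtype=np.float32))
logits = trt_model(batch, training=False)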

TensorFlow2/Segmentation/UNet_Medical/runtime/arguments.py

+6 -2

@@ -100,9 +100,12 @@
 PARSER.add_argument('--use_xla', '--xla', dest='use_xla', action='store_true',
                     help="""Train using XLA""")
 
-PARSER.add_argument('--use_trt', dest='use_trt', action='store_true',
+PARSER.add_argument('--use_tftrt', dest='use_tftrt', action='store_true',
                     help="""Use TF-TRT""")
 
+PARSER.add_argument('--use_savedmodel', dest='use_savedmodel', action='store_true',
+                    help="""Use SavedModel""")
+
 PARSER.add_argument('--resume_training', dest='resume_training', action='store_true',
                     help="""Resume training from a checkpoint""")
 
@@ -125,7 +128,8 @@ def parse_args(flags):
         'benchmark': flags.benchmark,
         'seed': flags.seed,
         'use_amp': flags.use_amp,
-        'use_trt': flags.use_trt,
+        'use_tftrt': flags.use_tftrt,
+        'use_savedmodel': flags.use_savedmodel,
         'use_xla': flags.use_xla,
         'resume_training': flags.resume_training,
     })
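Both new options are plain store_true switches, so they default to False and flip to True when present on the command line. A throwaway parser illustrating the parsing behaviour (not the repository's PARSER object):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_tftrt', dest='use_tftrt', action='store_true', help='Use TF-TRT')
parser.add_argument('--use_savedmodel', dest='use_savedmodel', action='store_true', help='Use SavedModel')

args = parser.parse_args(['--use_savedmodel'])
print(args.use_savedmodel, args.use_tftrt)  # True False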

TensorFlow2/Segmentation/UNet_Medical/runtime/run.py

+22 -4

@@ -21,6 +21,7 @@
 
 from runtime.losses import partial_losses
 from runtime.parse_results import process_performance_stats
+from model.tf_trt import export_model, TFTRTModel
 
 
 def train(params, model, dataset, logger):
@@ -101,6 +102,11 @@ def train_step(features, labels, warmup_batch=False):
             break
     if hvd.rank() == 0:
         checkpoint.save(file_prefix=os.path.join(params.model_dir, "checkpoint"))
+        if params.use_savedmodel:
+            prec = 'amp' if params.use_amp else 'fp32'
+            model.save(os.path.join(params.model_dir, f'saved_model_{prec}'))
+            if params.use_tftrt:
+                export_model(params.model_dir, prec, os.path.join(params.model_dir, f'tf-trt_model_{prec}'))
 
     logger.flush()
 
@@ -110,9 +116,15 @@ def evaluate(params, model, dataset, logger, restore_checkpoint=True):
         print("No fold specified for evaluation. Please use --fold [int] to select a fold.")
     ce_loss = tf.keras.metrics.Mean(name='ce_loss')
     f1_loss = tf.keras.metrics.Mean(name='dice_loss')
-    checkpoint = tf.train.Checkpoint(model=model)
     if params.model_dir and restore_checkpoint:
-        checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
+        prec = 'amp' if params.use_amp else 'fp32'
+        if params.use_savedmodel:
+            model = tf.keras.models.load_model(os.path.join(params.model_dir, f'saved_model_{prec}'))
+        elif params.use_tftrt:
+            model = TFTRTModel(model_dir=params.model_dir, precision=prec)
+        else:
+            checkpoint = tf.train.Checkpoint(model=model)
+            checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
 
     def validation_step(features, labels):
         output_map = model(features, training=False)
@@ -135,9 +147,15 @@ def validation_step(features, labels):
 
 
 def predict(params, model, dataset, logger):
-    checkpoint = tf.train.Checkpoint(model=model)
+    prec = 'amp' if params.use_amp else 'fp32'
     if params.model_dir:
-        checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
+        if params.use_savedmodel:
+            model = tf.keras.models.load_model(os.path.join(params.model_dir, f'saved_model_{prec}'))
+        elif params.use_tftrt:
+            model = TFTRTModel(model_dir=params.model_dir, precision=prec)
+        else:
+            checkpoint = tf.train.Checkpoint(model=model)
+            checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
 
     @tf.function
     def prediction_step(features):
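All three restore paths leave model as something callable like a Keras model, so validation_step and prediction_step keep calling model(features, training=False) unchanged (the TF-TRT wrapper simply ignores the training kwarg). A condensed sketch of the shared selection logic, factored into a hypothetical helper that is not part of the repository:

import os
import tensorflow as tf
from model.tf_trt import TFTRTModel

def restore_model(params, model):
    # Return a callable model: a SavedModel, a TF-TRT wrapper, or the checkpoint-restored Keras model.
    prec = 'amp' if params.use_amp else 'fp32'
    if params.use_savedmodel:
        return tf.keras.models.load_model(os.path.join(params.model_dir, f'saved_model_{prec}'))
    if params.use_tftrt:
        return TFTRTModel(model_dir=params.model_dir, precision=prec)
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
    return model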

TensorFlow2/Segmentation/UNet_Medical/runtime/setup.py

+2 -2

@@ -57,8 +57,8 @@ def set_flags(params):
 
 
 def prepare_model_dir(params):
-    model_dir = os.path.join(params.model_dir, "model_checkpoint")
-    model_dir = model_dir if (hvd.rank() == 0 and not params.benchmark) else None
+    # model_dir = os.path.join(params.model_dir, "model_checkpoint")
+    model_dir = params.model_dir if (hvd.rank() == 0 and not params.benchmark) else None
     if model_dir is not None:
         os.makedirs(model_dir, exist_ok=True)
         if ('train' in params.exec_mode) and (not params.resume_training):
