Skip to content

Commit 81b9010

Browse files
nvgarvitknv-kkudrynski
authored andcommitted
[DLRM TF2] Saved model improvements
1 parent 828e88d commit 81b9010

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

TensorFlow2/Recommendation/DLRM/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def define_command_line_flags():
4747

4848
flags.DEFINE_string("saved_model_output_path", default=None,
4949
help='Path for storing the model in TensorFlow SavedModel format')
50+
flags.DEFINE_bool("save_input_signature", default=False,
51+
help="Save input signature in the SavedModel")
5052
flags.DEFINE_string("saved_model_input_path", default=None,
5153
help='Path for loading the model in TensorFlow SavedModel format')
5254

@@ -334,7 +336,8 @@ def main(argv):
334336

335337
elapsed = time.time() - train_begin
336338
dlrm.save_checkpoint_if_path_exists(FLAGS.save_checkpoint_path)
337-
dlrm.save_model_if_path_exists(FLAGS.saved_model_output_path)
339+
dlrm.save_model_if_path_exists(FLAGS.saved_model_output_path,
340+
save_input_signature=FLAGS.save_input_signature)
338341

339342
if hvd.rank() == 0:
340343
dist_print(f'Training run completed, elapsed: {elapsed:.0f} [s]')

TensorFlow2/Recommendation/DLRM/model.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from utils import dist_print
2828
from tensorflow.python.saved_model import signature_constants
2929
from tensorflow.python.framework import convert_to_constants
30+
from tensorflow.python.keras.saving.saving_utils import model_input_signature
31+
from collections import OrderedDict
3032

3133
try:
3234
from tensorflow_dot_based_interact.python.ops import dot_based_interact_ops
@@ -43,6 +45,21 @@ def scale_grad(grad, factor):
4345
# dense gradient
4446
return grad * factor
4547

48+
def _create_inputs_dict(numerical_features, categorical_features):
49+
# Passing inputs as (numerical_features, categorical_features) changes the model
50+
# input signature to (<tensor, [list of tensors]>).
51+
# This leads to errors while loading the saved model.
52+
# TF flattens the inputs while loading the model,
53+
# so the inputs are converted from (<tensor, [list of tensors]>) -> [list of tensors]
54+
# see _set_inputs function in training_v1.py:
55+
# https://github.com/tensorflow/tensorflow/blob/7628750678786f1b65e8905fb9406d8fbffef0db/tensorflow/python/keras/engine/training_v1.py#L2588)
56+
inputs = OrderedDict()
57+
inputs['numerical_features'] = numerical_features
58+
59+
if categorical_features != -1:
60+
for count, c_feature in enumerate(categorical_features):
61+
inputs["categorical_features_" + str(count)] = c_feature
62+
return inputs
4663

4764
class DataParallelSplitter:
4865
def __init__(self, batch_size):
@@ -104,9 +121,9 @@ def _top_part_weight_update(self, unscaled_gradients):
104121
def train_step(self, numerical_features, categorical_features, labels):
105122
self.lr_scheduler()
106123

124+
inputs = _create_inputs_dict(numerical_features, categorical_features)
107125
with tf.GradientTape() as tape:
108-
predictions = self.dlrm(inputs=(numerical_features, categorical_features),
109-
training=True)
126+
predictions = self.dlrm(inputs=inputs, training=True)
110127

111128
unscaled_loss = self.bce(labels, predictions)
112129
# tf keras doesn't reduce the loss when using a Custom Training Loop
@@ -132,7 +149,6 @@ def train_step(self, numerical_features, categorical_features, labels):
132149
return mean_loss
133150

134151

135-
136152
def evaluate(validation_pipeline, dlrm, timer, auc_thresholds,
137153
data_parallel_splitter, max_steps=None, cast_dtype=None):
138154

@@ -167,7 +183,8 @@ def evaluate(validation_pipeline, dlrm, timer, auc_thresholds,
167183
if max_steps is not None and eval_step >= max_steps:
168184
break
169185

170-
y_pred = dlrm((numerical_features, categorical_features), False)
186+
inputs = _create_inputs_dict(numerical_features, categorical_features)
187+
y_pred = dlrm(inputs, False)
171188
end = time.time()
172189
latency = end - begin
173190
latencies.append(latency)
@@ -376,7 +393,8 @@ def force_initialization(self):
376393

377394
@tf.function
378395
def call(self, inputs, sigmoid=False):
379-
numerical_features, cat_features = inputs
396+
vals = list(inputs.values())
397+
numerical_features, cat_features = vals[0], vals[1:]
380398
embedding_outputs = self._call_embeddings(cat_features)
381399

382400
if self.running_bottom_mlp:
@@ -583,14 +601,25 @@ def restore_checkpoint_if_path_exists(self, checkpoint_path):
583601
dist_print('Restored a checkpoint from', checkpoint_path)
584602
return self
585603

586-
def save_model_if_path_exists(self, path):
604+
def save_model_if_path_exists(self, path, save_input_signature=False):
587605
if not path:
588606
return
589607

590608
if hvd.size() > 1:
591609
raise ValueError('SavedModel conversion not supported in HybridParallel mode')
592610

593-
tf.keras.models.save_model(model=self, filepath=path, overwrite=True)
611+
if save_input_signature:
612+
input_sig = model_input_signature(self, keep_original_batch_size=True)
613+
call_graph = tf.function(self)
614+
signatures = call_graph.get_concrete_function(input_sig[0])
615+
else:
616+
signatures = None
617+
618+
tf.keras.models.save_model(
619+
model=self,
620+
filepath=path,
621+
overwrite=True,
622+
signatures=signatures)
594623

595624
def load_model_if_path_exists(self, path):
596625
if not path:

0 commit comments

Comments
 (0)