27
27
from utils import dist_print
28
28
from tensorflow .python .saved_model import signature_constants
29
29
from tensorflow .python .framework import convert_to_constants
30
+ from tensorflow .python .keras .saving .saving_utils import model_input_signature
31
+ from collections import OrderedDict
30
32
31
33
try :
32
34
from tensorflow_dot_based_interact .python .ops import dot_based_interact_ops
@@ -43,6 +45,21 @@ def scale_grad(grad, factor):
43
45
# dense gradient
44
46
return grad * factor
45
47
48
+ def _create_inputs_dict (numerical_features , categorical_features ):
49
+ # Passing inputs as (numerical_features, categorical_features) changes the model
50
+ # input signature to (<tensor, [list of tensors]>).
51
+ # This leads to errors while loading the saved model.
52
+ # TF flattens the inputs while loading the model,
53
+ # so the inputs are converted from (<tensor, [list of tensors]>) -> [list of tensors]
54
+ # see _set_inputs function in training_v1.py:
55
+ # https://github.com/tensorflow/tensorflow/blob/7628750678786f1b65e8905fb9406d8fbffef0db/tensorflow/python/keras/engine/training_v1.py#L2588)
56
+ inputs = OrderedDict ()
57
+ inputs ['numerical_features' ] = numerical_features
58
+
59
+ if categorical_features != - 1 :
60
+ for count , c_feature in enumerate (categorical_features ):
61
+ inputs ["categorical_features_" + str (count )] = c_feature
62
+ return inputs
46
63
47
64
class DataParallelSplitter :
48
65
def __init__ (self , batch_size ):
@@ -104,9 +121,9 @@ def _top_part_weight_update(self, unscaled_gradients):
104
121
def train_step (self , numerical_features , categorical_features , labels ):
105
122
self .lr_scheduler ()
106
123
124
+ inputs = _create_inputs_dict (numerical_features , categorical_features )
107
125
with tf .GradientTape () as tape :
108
- predictions = self .dlrm (inputs = (numerical_features , categorical_features ),
109
- training = True )
126
+ predictions = self .dlrm (inputs = inputs , training = True )
110
127
111
128
unscaled_loss = self .bce (labels , predictions )
112
129
# tf keras doesn't reduce the loss when using a Custom Training Loop
@@ -132,7 +149,6 @@ def train_step(self, numerical_features, categorical_features, labels):
132
149
return mean_loss
133
150
134
151
135
-
136
152
def evaluate (validation_pipeline , dlrm , timer , auc_thresholds ,
137
153
data_parallel_splitter , max_steps = None , cast_dtype = None ):
138
154
@@ -167,7 +183,8 @@ def evaluate(validation_pipeline, dlrm, timer, auc_thresholds,
167
183
if max_steps is not None and eval_step >= max_steps :
168
184
break
169
185
170
- y_pred = dlrm ((numerical_features , categorical_features ), False )
186
+ inputs = _create_inputs_dict (numerical_features , categorical_features )
187
+ y_pred = dlrm (inputs , False )
171
188
end = time .time ()
172
189
latency = end - begin
173
190
latencies .append (latency )
@@ -376,7 +393,8 @@ def force_initialization(self):
376
393
377
394
@tf .function
378
395
def call (self , inputs , sigmoid = False ):
379
- numerical_features , cat_features = inputs
396
+ vals = list (inputs .values ())
397
+ numerical_features , cat_features = vals [0 ], vals [1 :]
380
398
embedding_outputs = self ._call_embeddings (cat_features )
381
399
382
400
if self .running_bottom_mlp :
@@ -583,14 +601,25 @@ def restore_checkpoint_if_path_exists(self, checkpoint_path):
583
601
dist_print ('Restored a checkpoint from' , checkpoint_path )
584
602
return self
585
603
586
def save_model_if_path_exists(self, path, save_input_signature=False):
    """Export this model with ``tf.keras.models.save_model`` when a path is given.

    Args:
        path: Destination passed to ``save_model`` as ``filepath``. A falsy
            value (``None``, ``''``) makes the call a no-op.
        save_input_signature: When true, the model's recorded input
            signature is looked up, the model is traced on it, and the
            resulting concrete function is passed as ``signatures``.
            When false, ``signatures=None`` is passed.

    Raises:
        ValueError: If ``hvd.size()`` reports more than one worker, or if
            ``save_input_signature`` is set but the model has no input
            signature to trace.
    """
    if not path:
        return

    if hvd.size() > 1:
        raise ValueError('SavedModel conversion not supported in HybridParallel mode')

    signatures = None
    if save_input_signature:
        input_sig = model_input_signature(self, keep_original_batch_size=True)
        # The lookup yields nothing when the model's inputs were never set
        # (presumably: the model has not been called yet -- confirm);
        # fail with a clear message instead of an opaque indexing error.
        if not input_sig:
            raise ValueError(
                'Cannot save the input signature: the model has no recorded '
                'inputs. Call the model on data before saving.')
        call_graph = tf.function(self)
        signatures = call_graph.get_concrete_function(input_sig[0])

    tf.keras.models.save_model(
        model=self,
        filepath=path,
        overwrite=True,
        signatures=signatures)
594
623
595
624
def load_model_if_path_exists (self , path ):
596
625
if not path :
0 commit comments