21
21
22
22
from runtime .losses import partial_losses
23
23
from runtime .parse_results import process_performance_stats
24
+ from model .tf_trt import export_model , TFTRTModel
24
25
25
26
26
27
def train (params , model , dataset , logger ):
@@ -101,6 +102,11 @@ def train_step(features, labels, warmup_batch=False):
101
102
break
102
103
if hvd .rank () == 0 :
103
104
checkpoint .save (file_prefix = os .path .join (params .model_dir , "checkpoint" ))
105
+ if params .use_savedmodel :
106
+ prec = 'amp' if params .use_amp else 'fp32'
107
+ model .save (os .path .join (params .model_dir , f'saved_model_{ prec } ' ))
108
+ if params .use_tftrt :
109
+ export_model (params .model_dir , prec , os .path .join (params .model_dir , f'tf-trt_model_{ prec } ' ))
104
110
105
111
logger .flush ()
106
112
@@ -110,9 +116,15 @@ def evaluate(params, model, dataset, logger, restore_checkpoint=True):
110
116
print ("No fold specified for evaluation. Please use --fold [int] to select a fold." )
111
117
ce_loss = tf .keras .metrics .Mean (name = 'ce_loss' )
112
118
f1_loss = tf .keras .metrics .Mean (name = 'dice_loss' )
113
- checkpoint = tf .train .Checkpoint (model = model )
114
119
if params .model_dir and restore_checkpoint :
115
- checkpoint .restore (tf .train .latest_checkpoint (params .model_dir )).expect_partial ()
120
+ prec = 'amp' if params .use_amp else 'fp32'
121
+ if params .use_savedmodel :
122
+ model = tf .keras .models .load_model (os .path .join (params .model_dir , f'saved_model_{ prec } ' ))
123
+ elif params .use_tftrt :
124
+ model = TFTRTModel (model_dir = params .model_dir , precision = prec )
125
+ else :
126
+ checkpoint = tf .train .Checkpoint (model = model )
127
+ checkpoint .restore (tf .train .latest_checkpoint (params .model_dir )).expect_partial ()
116
128
117
129
def validation_step (features , labels ):
118
130
output_map = model (features , training = False )
@@ -135,9 +147,15 @@ def validation_step(features, labels):
135
147
136
148
137
149
def predict (params , model , dataset , logger ):
138
- checkpoint = tf . train . Checkpoint ( model = model )
150
+ prec = 'amp' if params . use_amp else 'fp32'
139
151
if params .model_dir :
140
- checkpoint .restore (tf .train .latest_checkpoint (params .model_dir )).expect_partial ()
152
+ if params .use_savedmodel :
153
+ model = tf .keras .models .load_model (os .path .join (params .model_dir , f'saved_model_{ prec } ' ))
154
+ elif params .use_tftrt :
155
+ model = TFTRTModel (model_dir = params .model_dir , precision = prec )
156
+ else :
157
+ checkpoint = tf .train .Checkpoint (model = model )
158
+ checkpoint .restore (tf .train .latest_checkpoint (params .model_dir )).expect_partial ()
141
159
142
160
@tf .function
143
161
def prediction_step (features ):
0 commit comments