1
1
#!/usr/bin/python3
2
2
3
- # Copyright (c) 2019 , NVIDIA CORPORATION. All rights reserved.
3
+ # Copyright (c) 2020 , NVIDIA CORPORATION. All rights reserved.
4
4
#
5
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
6
# you may not use this file except in compliance with the License.
15
15
# limitations under the License.
16
16
17
17
import os
18
+ os .environ ['TF_CPP_MIN_LOG_LEVEL' ] = '3'
19
+
18
20
from functools import partial
19
21
import json
20
22
import logging
21
23
from argparse import ArgumentParser
24
+
22
25
import tensorflow as tf
26
+ tf .logging .set_verbosity (tf .logging .ERROR )
27
+
23
28
import numpy as np
24
29
import horovod .tensorflow as hvd
30
+ from mpi4py import MPI
25
31
import dllogger
32
+ import time
26
33
27
34
from vae .utils .round import round_8
28
35
from vae .metrics .recall import recall
32
39
33
40
def main ():
34
41
hvd .init ()
42
+ mpi_comm = MPI .COMM_WORLD
35
43
36
44
parser = ArgumentParser (description = "Train a Variational Autoencoder for Collaborative Filtering in TensorFlow" )
37
45
parser .add_argument ('--train' , action = 'store_true' ,
38
46
help = 'Run training of VAE' )
39
47
parser .add_argument ('--test' , action = 'store_true' ,
40
48
help = 'Run validation of VAE' )
41
- parser .add_argument ('--inference' , action = 'store_true' ,
42
- help = 'Run inference on a single random example.'
43
- 'This can also be used to measure the latency for a batch size of 1' )
44
49
parser .add_argument ('--inference_benchmark' , action = 'store_true' ,
45
- help = 'Benchmark the inference throughput on a very large batch size ' )
46
- parser .add_argument ('--use_tf_amp ' , action = 'store_true' ,
50
+ help = 'Measure inference latency and throughput on a variety of batch sizes ' )
51
+ parser .add_argument ('--amp ' , action = 'store_true' , default = False ,
47
52
help = 'Enable Automatic Mixed Precision' )
48
53
parser .add_argument ('--epochs' , type = int , default = 400 ,
49
54
help = 'Number of epochs to train' )
@@ -85,6 +90,7 @@ def main():
85
90
default = None ,
86
91
help = 'Path for saving a checkpoint after the training' )
87
92
args = parser .parse_args ()
93
+ args .world_size = hvd .size ()
88
94
89
95
if args .batch_size_train % hvd .size () != 0 :
90
96
raise ValueError ('Global batch size should be a multiple of the number of workers' )
@@ -101,16 +107,27 @@ def main():
101
107
dllogger .init (backends = [])
102
108
logger .setLevel (logging .ERROR )
103
109
104
- dllogger .log (data = vars (args ), step = 'PARAMETER' )
110
+ if args .seed is None :
111
+ if hvd .rank () == 0 :
112
+ seed = int (time .time ())
113
+ else :
114
+ seed = None
105
115
106
- np .random .seed (args .seed )
107
- tf .set_random_seed (args .seed )
116
+ seed = mpi_comm .bcast (seed , root = 0 )
117
+ else :
118
+ seed = args .seed
119
+
120
+ tf .random .set_random_seed (seed )
121
+ np .random .seed (seed )
122
+ args .seed = seed
123
+
124
+ dllogger .log (data = vars (args ), step = 'PARAMETER' )
108
125
109
126
# Suppress TF warnings
110
127
os .environ ['TF_CPP_MIN_LOG_LEVEL' ] = '2'
111
128
112
129
# set AMP
113
- os .environ ['TF_ENABLE_AUTO_MIXED_PRECISION' ] = '1' if args .use_tf_amp else '0'
130
+ os .environ ['TF_ENABLE_AUTO_MIXED_PRECISION' ] = '1' if args .amp else '0'
114
131
115
132
# load dataset
116
133
(train_data ,
@@ -159,21 +176,36 @@ def main():
159
176
elif args .test and hvd .size () > 1 :
160
177
print ("Testing is not supported with horovod multigpu yet" )
161
178
162
- if args .inference_benchmark and hvd .size () <= 1 :
163
- # use the train data to get accurate throughput numbers for inference
164
- # the test and validation sets are too small to measure this accurately
165
- # vae.inference_benchmark()
166
- _ = vae .test (test_data_input = train_data ,
167
- test_data_true = train_data , metrics = {})
168
-
169
-
170
179
elif args .test and hvd .size () > 1 :
171
180
print ("Testing is not supported with horovod multigpu yet" )
172
181
173
- if args .inference :
174
- input_data = np .random .randint (low = 0 , high = 10000 , size = 10 )
175
- recommendations = vae .query (input_data = input_data )
176
- print ('Recommended item indices: ' , recommendations )
182
+ if args .inference_benchmark :
183
+ items_per_user = 10
184
+ item_indices = np .random .randint (low = 0 , high = 10000 , size = items_per_user )
185
+ user_indices = np .zeros (len (item_indices ))
186
+ indices = np .stack ([user_indices , item_indices ], axis = 1 )
187
+
188
+ num_batches = 200
189
+ latencies = []
190
+ for i in range (num_batches ):
191
+ start_time = time .time ()
192
+ _ = vae .query (indices = indices )
193
+ end_time = time .time ()
194
+
195
+ if i < 10 :
196
+ #warmup steps
197
+ continue
198
+
199
+ latencies .append (end_time - start_time )
200
+
201
+ result_data = {}
202
+ result_data [f'batch_1_mean_throughput' ] = 1 / np .mean (latencies )
203
+ result_data [f'batch_1_mean_latency' ] = np .mean (latencies )
204
+ result_data [f'batch_1_p90_latency' ] = np .percentile (latencies , 90 )
205
+ result_data [f'batch_1_p95_latency' ] = np .percentile (latencies , 95 )
206
+ result_data [f'batch_1_p99_latency' ] = np .percentile (latencies , 99 )
207
+
208
+ dllogger .log (data = result_data , step = tuple ())
177
209
178
210
vae .close_session ()
179
211
dllogger .flush ()
0 commit comments