@@ -155,7 +155,11 @@ def vgg16(num_classes=1000, init_weights=False):
155
155
),
156
156
)
157
157
training_dataloader = torch .utils .data .DataLoader (
158
- training_dataset , batch_size = args .batch_size , shuffle = True , num_workers = 2
158
+ training_dataset ,
159
+ batch_size = args .batch_size ,
160
+ shuffle = True ,
161
+ num_workers = 2 ,
162
+ drop_last = True ,
159
163
)
160
164
161
165
data = iter (training_dataloader )
@@ -211,8 +215,12 @@ def calibrate_loop(model):
211
215
)
212
216
213
217
testing_dataloader = torch .utils .data .DataLoader (
214
- testing_dataset , batch_size = args .batch_size , shuffle = False , num_workers = 2
215
- )
218
+ testing_dataset ,
219
+ batch_size = args .batch_size ,
220
+ shuffle = False ,
221
+ num_workers = 2 ,
222
+ drop_last = True ,
223
+ ) # set drop_last=True to drop the last incomplete batch for static shape `torchtrt.dynamo.compile()`
216
224
217
225
with torch .no_grad ():
218
226
with export_torch_mode ():
@@ -235,10 +243,9 @@ def calibrate_loop(model):
235
243
loss = 0.0
236
244
class_probs = []
237
245
class_preds = []
238
- model .eval ()
239
246
for data , labels in testing_dataloader :
240
247
data , labels = data .cuda (), labels .cuda (non_blocking = True )
241
- out = model (data )
248
+ out = trt_model (data )
242
249
loss += crit (out , labels )
243
250
preds = torch .max (out , 1 )[1 ]
244
251
class_probs .append ([F .softmax (i , dim = 0 ) for i in out ])
0 commit comments