Skip to content

Commit f909b07

Browse files
authored
fix: bug in vgg16_fp8_ptq example (#2950)
1 parent 9b5252d commit f909b07

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

examples/dynamo/vgg16_fp8_ptq.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,11 @@ def vgg16(num_classes=1000, init_weights=False):
155155
),
156156
)
157157
training_dataloader = torch.utils.data.DataLoader(
158-
training_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2
158+
training_dataset,
159+
batch_size=args.batch_size,
160+
shuffle=True,
161+
num_workers=2,
162+
drop_last=True,
159163
)
160164

161165
data = iter(training_dataloader)
@@ -211,8 +215,12 @@ def calibrate_loop(model):
211215
)
212216

213217
testing_dataloader = torch.utils.data.DataLoader(
214-
testing_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2
215-
)
218+
testing_dataset,
219+
batch_size=args.batch_size,
220+
shuffle=False,
221+
num_workers=2,
222+
drop_last=True,
223+
) # set drop_last=True to drop the last incomplete batch for static shape `torchtrt.dynamo.compile()`
216224

217225
with torch.no_grad():
218226
with export_torch_mode():
@@ -235,10 +243,9 @@ def calibrate_loop(model):
235243
loss = 0.0
236244
class_probs = []
237245
class_preds = []
238-
model.eval()
239246
for data, labels in testing_dataloader:
240247
data, labels = data.cuda(), labels.cuda(non_blocking=True)
241-
out = model(data)
248+
out = trt_model(data)
242249
loss += crit(out, labels)
243250
preds = torch.max(out, 1)[1]
244251
class_probs.append([F.softmax(i, dim=0) for i in out])

0 commit comments

Comments
 (0)