Skip to content

Commit 6f73a23

Browse files
authored
Merge pull request #1529 from gs-olive/perf_docs
feat: Add functionality to FX benchmarking + Improve documentation
2 parents 27733ba + 360f6c4 commit 6f73a23

File tree

4 files changed

+106
-20
lines changed

4 files changed

+106
-20
lines changed

tools/perf/README.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,10 @@ There are two sample configuration files added.
6666

6767
| Name | Supported Values | Description |
6868
| ----------------- | ------------------------------------ | ------------------------------------------------------------ |
69-
| backend | all, torch, torch_tensorrt, tensorrt | Supported backends for inference. |
69+
| backend | all, torchscript, fx2trt, torch, torch_tensorrt, tensorrt | Supported backends for inference. "all" runs the last four backends in the list at left (fx2trt, torch, torch_tensorrt, tensorrt), and "torchscript" runs the last three (torch, torch_tensorrt, tensorrt), i.e. it excludes the fx2trt path. |
7070
| input | - | Input binding names. Expected to list the shape of each input binding |
7171
| model | - | Configure the model filename and name |
72+
| model_torch | - | Configure the PyTorch model filename and name (optional; only used for fx2trt) |
7273
| filename | - | Model file name to load from disk. |
7374
| name | - | Model name |
7475
| runtime | - | Runtime configurations |
@@ -83,6 +84,7 @@ backend:
8384
- torch
8485
- torch_tensorrt
8586
- tensorrt
87+
- fx2trt
8688
input:
8789
input0:
8890
- 3
@@ -92,6 +94,9 @@ input:
9294
model:
9395
filename: model.plan
9496
name: vgg16
97+
model_torch:
98+
filename: model_torch.pt
99+
name: vgg16
95100
runtime:
96101
device: 0
97102
precision:
@@ -108,8 +113,9 @@ Note:
108113

109114
Here is the list of `CompileSpec` options that can be provided directly to compile the PyTorch module
110115

111-
* `--backends` : Comma separated string of backends. Eg: torch,torch_tensorrt, tensorrt or fx2trt
116+
* `--backends` : Comma separated string of backends. Eg: torch,torch_tensorrt,tensorrt,fx2trt
112117
* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `fx2trt`, the input should be a Pytorch module (instead of a torchscript module) and the options for model are (`vgg16` | `resnet50` | `efficientnet_b0`)
118+
* `--model_torch` : Name of the PyTorch model file (optional, only necessary if fx2trt is a chosen backend)
113119
* `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
114120
* `--batch_size` : Batch size
115121
* `--precision` : Comma separated list of precisions to build TensorRT engine Eg: fp32,fp16
@@ -122,9 +128,10 @@ Eg:
122128

123129
```
124130
python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
131+
--model_torch ${MODELS_DIR}/vgg16_torch.pt \
125132
--precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \
126133
--batch_size 1 \
127-
--backends torch,torch_tensorrt,tensorrt \
134+
--backends torch,torch_tensorrt,tensorrt,fx2trt \
128135
--report "vgg_perf_bs1.txt"
129136
```
130137

tools/perf/hub.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626

2727
# Key models selected for benchmarking with their respective paths
2828
BENCHMARK_MODELS = {
29-
"vgg16": {"model": models.vgg16(pretrained=True), "path": ["script", "pytorch"]},
29+
"vgg16": {
30+
"model": models.vgg16(weights=models.VGG16_Weights.DEFAULT),
31+
"path": ["script", "pytorch"],
32+
},
3033
"resnet50": {
3134
"model": models.resnet50(weights=None),
3235
"path": ["script", "pytorch"],

tools/perf/perf_run.py

Lines changed: 88 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,20 @@ def run(
292292
print("int8 precision expects calibration cache file for inference")
293293
return False
294294

295+
if (model is None) and (backend != "fx2trt"):
296+
warnings.warn(
297+
f"Requested backend {backend} without specifying a TorchScript Model, "
298+
+ "skipping this backend"
299+
)
300+
continue
301+
302+
if (model_torch is None) and (backend in ("all", "fx2trt")):
303+
warnings.warn(
304+
f"Requested backend {backend} without specifying a PyTorch Model, "
305+
+ "skipping this backend"
306+
)
307+
continue
308+
295309
if backend == "all":
296310
run_torch(model, input_tensors, params, precision, batch_size)
297311
run_torch_tensorrt(
@@ -311,6 +325,27 @@ def run(
311325
is_trt_engine,
312326
batch_size,
313327
)
328+
run_fx2trt(model_torch, input_tensors, params, precision, batch_size)
329+
330+
elif backend == "torchscript":
331+
run_torch(model, input_tensors, params, precision, batch_size)
332+
run_torch_tensorrt(
333+
model,
334+
input_tensors,
335+
params,
336+
precision,
337+
truncate_long_and_double,
338+
batch_size,
339+
)
340+
run_tensorrt(
341+
model,
342+
input_tensors,
343+
params,
344+
precision,
345+
truncate_long_and_double,
346+
is_trt_engine,
347+
batch_size,
348+
)
314349

315350
elif backend == "torch":
316351
run_torch(model, input_tensors, params, precision, batch_size)
@@ -326,12 +361,6 @@ def run(
326361
)
327362

328363
elif backend == "fx2trt":
329-
if model_torch is None:
330-
warnings.warn(
331-
"Requested backend fx2trt without specifying a PyTorch Model, "
332-
+ "skipping this backend"
333-
)
334-
continue
335364
run_fx2trt(model_torch, input_tensors, params, precision, batch_size)
336365

337366
elif backend == "tensorrt":
@@ -371,9 +400,14 @@ def recordStats(backend, timings, precision, batch_size=1, compile_time_ms=None)
371400
results.append(stats)
372401

373402

374-
def load_model(params):
403+
def load_ts_model(params):
375404
model = None
376405
is_trt_engine = False
406+
407+
# No TorchScript Model Specified
408+
if len(params.get("model", "")) == 0:
409+
return None, None, is_trt_engine
410+
377411
# Load torch model traced/scripted
378412
model_file = params.get("model").get("filename")
379413
try:
@@ -393,6 +427,26 @@ def load_model(params):
393427
return model, model_name, is_trt_engine
394428

395429

430+
def load_torch_model(params):
431+
model = None
432+
433+
# No Torch Model Specified
434+
if len(params.get("model_torch", "")) == 0:
435+
return None, None
436+
437+
# Load torch model
438+
model_file = params.get("model_torch").get("filename")
439+
try:
440+
model_name = params.get("model_torch").get("name")
441+
except:
442+
model_name = model_file
443+
444+
print("Loading Torch model: ", model_file)
445+
model = torch.load(model_file).cuda()
446+
447+
return model, model_name
448+
449+
396450
if __name__ == "__main__":
397451
arg_parser = argparse.ArgumentParser(
398452
description="Run inference on a model with random input values"
@@ -408,7 +462,9 @@ def load_model(params):
408462
type=str,
409463
help="Comma separated string of backends. Eg: torch,torch_tensorrt,fx2trt,tensorrt",
410464
)
411-
arg_parser.add_argument("--model", type=str, help="Name of torchscript model file")
465+
arg_parser.add_argument(
466+
"--model", type=str, default="", help="Name of torchscript model file"
467+
)
412468
arg_parser.add_argument(
413469
"--model_torch",
414470
type=str,
@@ -458,7 +514,16 @@ def load_model(params):
458514
parser = ConfigParser(args.config)
459515
# Load YAML params
460516
params = parser.read_config()
461-
model, model_name, is_trt_engine = load_model(params)
517+
model, model_name, is_trt_engine = load_ts_model(params)
518+
model_torch, model_name_torch = load_torch_model(params)
519+
520+
# If neither model type was provided
521+
if (model is None) and (model_torch is None):
522+
raise ValueError(
523+
"No valid models specified. Please provide a torchscript model file or model name "
524+
+ "(among the following options vgg16|resnet50|efficientnet_b0|vit) "
525+
+ "or provide a torch model file"
526+
)
462527

463528
# Default device is set to 0. Configurable using yaml config file.
464529
torch.cuda.set_device(params.get("runtime").get("device", 0))
@@ -489,7 +554,10 @@ def load_model(params):
489554

490555
if not is_trt_engine and (precision == "fp16" or precision == "half"):
491556
# If model is TensorRT serialized engine then model.half will report failure
492-
model = model.half()
557+
if model is not None:
558+
model = model.half()
559+
if model_torch is not None:
560+
model_torch = model_torch.half()
493561

494562
backends = params.get("backend")
495563
# Run inference
@@ -502,6 +570,7 @@ def load_model(params):
502570
truncate_long_and_double,
503571
batch_size,
504572
is_trt_engine,
573+
model_torch,
505574
)
506575
else:
507576
params = vars(args)
@@ -511,23 +580,27 @@ def load_model(params):
511580
model_name_torch = params["model_torch"]
512581
model_torch = None
513582

514-
# Load TorchScript model
583+
# Load TorchScript model, if provided
515584
if os.path.exists(model_name):
516585
print("Loading user provided torchscript model: ", model_name)
517586
model = torch.jit.load(model_name).cuda().eval()
518587
elif model_name in BENCHMARK_MODELS:
519588
print("Loading torchscript model from BENCHMARK_MODELS for: ", model_name)
520589
model = BENCHMARK_MODELS[model_name]["model"].eval().cuda()
521-
else:
522-
raise ValueError(
523-
"Invalid model name. Please provide a torchscript model file or model name (among the following options vgg16|resnet50|efficientnet_b0|vit)"
524-
)
525590

526591
# Load PyTorch Model, if provided
527592
if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
528593
print("Loading user provided torch model: ", model_name_torch)
529594
model_torch = torch.load(model_name_torch).eval().cuda()
530595

596+
# If neither model type was provided
597+
if (model is None) and (model_torch is None):
598+
raise ValueError(
599+
"No valid models specified. Please provide a torchscript model file or model name "
600+
+ "(among the following options vgg16|resnet50|efficientnet_b0|vit) "
601+
+ "or provide a torch model file"
602+
)
603+
531604
backends = parse_backends(params["backends"])
532605
truncate_long_and_double = params["truncate"]
533606
batch_size = params["batch_size"]

tools/perf/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import timm
66

77
BENCHMARK_MODELS = {
8-
"vgg16": {"model": models.vgg16(pretrained=True), "path": ["script", "pytorch"]},
8+
"vgg16": {
9+
"model": models.vgg16(weights=models.VGG16_Weights.DEFAULT),
10+
"path": ["script", "pytorch"],
11+
},
912
"resnet50": {
1013
"model": models.resnet50(weights=None),
1114
"path": ["script", "pytorch"],

0 commit comments

Comments
 (0)