Skip to content

Commit 757b79f

Browse files
committed
feat: Add example usage scripts for dynamo path
- Add sample scripts covering resnet18, transformers, and custom examples showcasing the `torch_tensorrt.dynamo.compile` path, which can compile models with data-dependent control flow and other such restrictions which can make other compilation methods more difficult - Cover different customizable features allowed in the new backend - Make scripts Sphinx-Gallery compatible Python files
1 parent e7f4752 commit 757b79f

8 files changed

+309
-10
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ docsrc/_build
3232
docsrc/_notebooks
3333
docsrc/_cpp_api
3434
docsrc/_tmp
35+
docsrc/tutorials/_rendered_examples
3536
*.so
3637
__pycache__
3738
*.egg-info
@@ -66,4 +67,4 @@ bazel-tensorrt
6667
*.cache
6768
*cifar-10-batches-py*
6869
bazel-project
69-
build/
70+
build/

docsrc/conf.py

+7
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"sphinx.ext.coverage",
4848
"sphinx.ext.mathjax",
4949
"sphinx.ext.viewcode",
50+
"sphinx_gallery.gen_gallery",
5051
]
5152

5253
napoleon_use_ivar = True
@@ -79,6 +80,12 @@
7980
# so a file named "default.css" will overwrite the builtin "default.css".
8081
html_static_path = ["_static"]
8182

83+
# sphinx-gallery configuration
84+
sphinx_gallery_conf = {
85+
"examples_dirs": "../examples/dynamo",
86+
"gallery_dirs": "tutorials/_rendered_examples/",
87+
}
88+
8289
# Setup the breathe extension
8390
breathe_projects = {"Torch-TensorRT": "./_tmp/xml"}
8491
breathe_default_project = "Torch-TensorRT"

docsrc/index.rst

+20-9
Original file line numberDiff line numberDiff line change
@@ -36,30 +36,41 @@ Getting Started
3636
getting_started/getting_started_with_windows
3737

3838

39-
Tutorials
39+
User Guide
4040
------------
4141
* :ref:`creating_a_ts_mod`
4242
* :ref:`getting_started_with_fx`
4343
* :ref:`ptq`
4444
* :ref:`runtime`
45-
* :ref:`serving_torch_tensorrt_with_triton`
4645
* :ref:`use_from_pytorch`
4746
* :ref:`using_dla`
47+
48+
.. toctree::
49+
:caption: User Guide
50+
:maxdepth: 1
51+
:hidden:
52+
53+
user_guide/creating_torchscript_module_in_python
54+
user_guide/getting_started_with_fx_path
55+
user_guide/ptq
56+
user_guide/runtime
57+
user_guide/use_from_pytorch
58+
user_guide/using_dla
59+
60+
Tutorials
61+
------------
62+
* :ref:`serving_torch_tensorrt_with_triton`
4863
* :ref:`notebooks`
64+
* :ref:`dynamo_compile`
4965

5066
.. toctree::
5167
:caption: Tutorials
52-
:maxdepth: 1
68+
:maxdepth: 3
5369
:hidden:
5470

55-
tutorials/creating_torchscript_module_in_python
56-
tutorials/getting_started_with_fx_path
57-
tutorials/ptq
58-
tutorials/runtime
5971
tutorials/serving_torch_tensorrt_with_triton
60-
tutorials/use_from_pytorch
61-
tutorials/using_dla
6272
tutorials/notebooks
73+
tutorials/_rendered_examples/index
6374

6475
Python API Documentation
6576
------------------------

docsrc/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
sphinx==4.5.0
2+
sphinx-gallery==0.13.0
23
breathe==4.33.1
34
exhale==0.3.1
45
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme

examples/dynamo/README.rst

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
.. _dynamo_compile:
2+
3+
Dynamo Compile Examples
4+
=======================
5+
6+
This document contains usage examples for the ``torch_tensorrt.dynamo.compile`` API, which integrates with ``torch.compile`` functionality.
7+
8+
Overview of Available Scripts
9+
-----------------------------------------------
10+
- `dynamo_compile_resnet_example.py <./dynamo_compile_resnet_example.html>`_: Example showcasing compilation of ResNet model
11+
- `dynamo_compile_transformers_example.py <./dynamo_compile_transformers_example.html>`_: Example showcasing compilation of transformer-based model
12+
- `dynamo_compile_advanced_usage.py <./dynamo_compile_advanced_usage.html>`_: Advanced usage including making a custom backend to use directly with the `torch.compile` API
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
Dynamo Compile Advanced Usage
3+
=============================
4+
5+
This interactive script is intended as an overview of the process by which `torch_tensorrt.dynamo.compile` works, and how it integrates with the new `torch.compile` API."""
6+
7+
# %%
8+
# Imports and Model Definition
9+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
10+
11+
import torch
12+
from torch_tensorrt.fx.lower_setting import LowerPrecision
13+
14+
# %%
15+
16+
# We begin by defining a model
17+
class Model(torch.nn.Module):
18+
def __init__(self) -> None:
19+
super().__init__()
20+
self.relu = torch.nn.ReLU()
21+
22+
def forward(self, x: torch.Tensor, y: torch.Tensor):
23+
x_out = self.relu(x)
24+
y_out = self.relu(y)
25+
x_y_out = x_out + y_out
26+
return torch.mean(x_y_out)
27+
28+
29+
# %%
30+
# Compilation with `torch.compile` Using Default Settings
31+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
32+
33+
# Define sample float inputs and initialize model
34+
sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()]
35+
model = Model().eval().cuda()
36+
37+
# %%
38+
39+
# Next, we compile the model using torch.compile
40+
# For the default settings, we can simply call torch.compile
41+
# with the backend "torch_tensorrt", and run the model on an
42+
# input to cause compilation, as so:
43+
optimized_model = torch.compile(model, backend="torch_tensorrt")
44+
optimized_model(*sample_inputs)
45+
46+
# %%
47+
# Compilation with `torch.compile` Using Custom Settings
48+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
49+
50+
# First, we use Torch utilities to clean up the workspace
51+
# after the previous compile invocation
52+
torch._dynamo.reset()
53+
54+
# Define sample half inputs and initialize model
55+
sample_inputs_half = [
56+
torch.rand((5, 7)).half().cuda(),
57+
torch.rand((5, 7)).half().cuda(),
58+
]
59+
model_half = Model().eval().cuda()
60+
61+
# %%
62+
63+
# If we want to customize certain options in the backend,
64+
# but still use the torch.compile call directly, we can provide
65+
# custom options to the backend via the "options" keyword
66+
# which takes in a dictionary mapping options to values.
67+
#
68+
# For accepted backend options, see the CompilationSettings dataclass:
69+
# py/torch_tensorrt/dynamo/backend/_settings.py
70+
backend_kwargs = {
71+
"precision": LowerPrecision.FP16,
72+
"debug": True,
73+
"min_block_size": 2,
74+
"torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
75+
"optimization_level": 4,
76+
"use_experimental_rt": True,
77+
}
78+
79+
# Run the model on an input to cause compilation, as so:
80+
optimized_model_custom = torch.compile(
81+
model_half, backend="torch_tensorrt", options=backend_kwargs
82+
)
83+
optimized_model_custom(*sample_inputs_half)
84+
85+
# %%
86+
# Cleanup
87+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
88+
89+
# Finally, we use Torch utilities to clean up the workspace
90+
torch._dynamo.reset()
91+
92+
with torch.no_grad():
93+
torch.cuda.empty_cache()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""
2+
Dynamo Compile ResNet Example
3+
=============================
4+
5+
This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a ResNet model."""
6+
7+
# %%
8+
# Imports and Model Definition
9+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
10+
11+
import torch
12+
import torch_tensorrt
13+
import torchvision.models as models
14+
15+
# %%
16+
17+
# Initialize model with half precision and sample inputs
18+
model = models.resnet18(pretrained=True).half().eval().to("cuda")
19+
inputs = [torch.randn((1, 3, 224, 224)).to("cuda").half()]
20+
21+
# %%
22+
# Optional Input Arguments to `torch_tensorrt.dynamo.compile`
23+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
24+
25+
# Enabled precision for TensorRT optimization
26+
enabled_precisions = {torch.half}
27+
28+
# Whether to print verbose logs
29+
debug = True
30+
31+
# Workspace size for TensorRT
32+
workspace_size = 20 << 30
33+
34+
# Maximum number of TRT Engines
35+
# (Lower value allows more graph segmentation)
36+
min_block_size = 3
37+
38+
# Operations to Run in Torch, regardless of converter support
39+
torch_executed_ops = set()  # empty set of op names; a bare `{}` would create a dict, not a set
40+
41+
# %%
42+
# Compilation with `torch_tensorrt.dynamo.compile`
43+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
44+
45+
# Build and compile the model with torch.compile, using Torch-TensorRT backend
46+
optimized_model = torch_tensorrt.dynamo.compile(
47+
model,
48+
inputs,
49+
enabled_precisions=enabled_precisions,
50+
debug=debug,
51+
workspace_size=workspace_size,
52+
min_block_size=min_block_size,
53+
torch_executed_ops=torch_executed_ops,
54+
)
55+
56+
# %%
57+
# Equivalently, we could have run the above via the convenience frontend, as so:
58+
# `torch_tensorrt.compile(model, ir="dynamo_compile", inputs=inputs, ...)`
59+
60+
# %%
61+
# Inference
62+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
63+
64+
# Does not cause recompilation (same batch size as input)
65+
new_inputs = [torch.randn((1, 3, 224, 224)).half().to("cuda")]
66+
new_outputs = optimized_model(*new_inputs)
67+
68+
# %%
69+
70+
# Does cause recompilation (new batch size)
71+
new_batch_size_inputs = [torch.randn((8, 3, 224, 224)).half().to("cuda")]
72+
new_batch_size_outputs = optimized_model(*new_batch_size_inputs)
73+
74+
# %%
75+
# Cleanup
76+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
77+
78+
# Finally, we use Torch utilities to clean up the workspace
79+
torch._dynamo.reset()
80+
81+
with torch.no_grad():
82+
torch.cuda.empty_cache()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""
2+
Dynamo Compile Transformers Example
3+
===================================
4+
5+
This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model."""
6+
7+
# %%
8+
# Imports and Model Definition
9+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
10+
11+
import torch
12+
import torch_tensorrt
13+
from transformers import BertModel
14+
15+
# %%
16+
17+
# Initialize model with float precision and sample inputs
18+
model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
19+
inputs = [
20+
torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
21+
torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
22+
]
23+
24+
25+
# %%
26+
# Optional Input Arguments to `torch_tensorrt.dynamo.compile`
27+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
28+
29+
# Enabled precision for TensorRT optimization
30+
enabled_precisions = {torch.float}
31+
32+
# Whether to print verbose logs
33+
debug = True
34+
35+
# Workspace size for TensorRT
36+
workspace_size = 20 << 30
37+
38+
# Maximum number of TRT Engines
39+
# (Lower value allows more graph segmentation)
40+
min_block_size = 3
41+
42+
# Operations to Run in Torch, regardless of converter support
43+
torch_executed_ops = set()  # empty set of op names; a bare `{}` would create a dict, not a set
44+
45+
# %%
46+
# Compilation with `torch_tensorrt.dynamo.compile`
47+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
48+
49+
# Build and compile the model with torch.compile, using Torch-TensorRT backend
50+
optimized_model = torch_tensorrt.dynamo.compile(
51+
model,
52+
inputs,
53+
enabled_precisions=enabled_precisions,
54+
debug=debug,
55+
workspace_size=workspace_size,
56+
min_block_size=min_block_size,
57+
torch_executed_ops=torch_executed_ops,
58+
)
59+
60+
# %%
61+
# Equivalently, we could have run the above via the convenience frontend, as so:
62+
# `torch_tensorrt.compile(model, ir="dynamo_compile", inputs=inputs, ...)`
63+
64+
# %%
65+
# Inference
66+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
67+
68+
# Does not cause recompilation (same batch size as input)
69+
new_inputs = [
70+
torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
71+
torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
72+
]
73+
new_outputs = optimized_model(*new_inputs)
74+
75+
# %%
76+
77+
# Does cause recompilation (new batch size)
78+
new_inputs = [
79+
torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"),
80+
torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"),
81+
]
82+
new_outputs = optimized_model(*new_inputs)
83+
84+
# %%
85+
# Cleanup
86+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
87+
88+
# Finally, we use Torch utilities to clean up the workspace
89+
torch._dynamo.reset()
90+
91+
with torch.no_grad():
92+
torch.cuda.empty_cache()

0 commit comments

Comments
 (0)