
Commit 6571252

peri044 and gs-olive authored
feat: Implement Dynamic shapes + fallback support for export path (#2271)
Signed-off-by: Dheeraj Peri <[email protected]> Co-authored-by: gs-olive <[email protected]>
1 parent 8c25baf commit 6571252

21 files changed: 643 additions, 82 deletions

.github/workflows/build-test.yml

Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@ jobs:
       cd tests/py/dynamo
       ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
       ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
+      ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_dyn_models.py
       popd

   tests-py-torch-compile-be:

cpp/include/torch_tensorrt/torch_tensorrt.h

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ class DataType {
   enum Value : int8_t {
     /// INT64
     kLong,
+    /// FP64
+    kDouble,
     /// FP32
     kFloat,
     /// FP16

cpp/src/types.cpp

Lines changed: 7 additions & 1 deletion
@@ -97,6 +97,8 @@ at::ScalarType toAtenDataType(DataType value) {
       return at::kInt;
     case DataType::kLong:
       return at::kLong;
+    case DataType::kDouble:
+      return at::kDouble;
     case DataType::kBool:
       return at::kBool;
     case DataType::kFloat:
@@ -119,7 +121,8 @@ nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {

 DataType::DataType(c10::ScalarType t) {
   TORCHTRT_CHECK(
-      t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kInt || t == at::kBool,
+      t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kDouble || t == at::kInt ||
+          t == at::kBool,
       "Data type is unsupported (" << t << ")");
   switch (t) {
     case at::kHalf:
@@ -134,6 +137,9 @@ DataType::DataType(c10::ScalarType t) {
     case at::kLong:
       value = DataType::kLong;
       break;
+    case at::kDouble:
+      value = DataType::kDouble;
+      break;
     case at::kBool:
       value = DataType::kBool;
       break;

docsrc/index.rst

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,7 @@ User Guide
 * :ref:`getting_started_with_fx`
 * :ref:`ptq`
 * :ref:`runtime`
+* :ref:`dynamic_shapes`
 * :ref:`use_from_pytorch`
 * :ref:`using_dla`

@@ -54,6 +55,7 @@ User Guide
    user_guide/getting_started_with_fx_path
    user_guide/ptq
    user_guide/runtime
+   user_guide/dynamic_shapes
    user_guide/use_from_pytorch
    user_guide/using_dla

docsrc/user_guide/dynamic_shapes.rst

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
.. _dynamic_shapes:

Dynamic shapes with Torch-TensorRT
====================================

By default, you can run a PyTorch model with varied input shapes, and the output shapes are determined eagerly.
However, Torch-TensorRT is an AOT compiler which requires some prior information about the input shapes to compile and optimize the model.
In the case of dynamic input shapes, we must provide the ``(min_shape, opt_shape, max_shape)`` arguments so that the model can be optimized for
this range of input shapes. An example usage of static and dynamic shapes is as follows.

NOTE: The following code uses the dynamo IR. In case of the TorchScript IR, please swap out ``ir=dynamo`` with ``ir=ts``; the behavior is exactly the same.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    # Compile with static shapes
    inputs = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)
    # or compile with dynamic shapes
    inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                                  opt_shape=[4, 3, 224, 224],
                                  max_shape=[8, 3, 224, 224],
                                  dtype=torch.float32)
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

Under the hood
--------------

There are two phases of compilation when we use the ``torch_tensorrt.compile`` API with ``ir=dynamo`` (default).

- aten_tracer.trace (which uses torch.export to trace the graph with the given inputs)

  In the tracing phase, we use torch.export along with the constraints. In the case of
  dynamic shaped inputs, the range can be provided to the tracing via constraints. Please
  refer to this `docstring <https://github.com/pytorch/pytorch/blob/5dcee01c2b89f6bedeef9dd043fd8d6728286582/torch/export/__init__.py#L372-L434>`_
  for detailed information on how to set constraints. In short, we create new inputs for
  torch.export tracing and provide constraints on the min and max values (provided by the user) that a particular dimension can take.
  Please take a look at the ``aten_tracer.py`` file to understand how this works under the hood.

- dynamo.compile (which compiles a torch.fx.GraphModule object using TensorRT)

  In the conversion to TensorRT, we use the user-provided dynamic shape inputs.
  We perform shape analysis using dummy inputs (across min, opt and max shapes) and store the
  intermediate output shapes, which can be used in case the graph has a mix of PyTorch
  and TensorRT submodules. A minimal sketch of calling these two phases directly is shown below.
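
For reference, here is a minimal sketch of invoking these two phases directly, mirroring what ``torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)`` does internally (``MyModel`` is a placeholder for your own module):

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = [torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                                   opt_shape=[4, 3, 224, 224],
                                   max_shape=[8, 3, 224, 224],
                                   dtype=torch.float32)]

    # Phase 1: trace the module with torch.export, applying the dynamic-dimension constraints
    graph_module = torch_tensorrt.dynamo.trace(model, inputs)

    # Phase 2: convert the traced torch.fx.GraphModule to TensorRT, running shape
    # analysis over the min/opt/max shapes for any PyTorch fallback submodules
    trt_gm = torch_tensorrt.dynamo.compile(graph_module, inputs=inputs)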

Custom Constraints
------------------

Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
Torch-TensorRT automatically sets the constraints during ``torch.export`` tracing as follows

.. code-block:: python

    for dim in constraint_dims:
        if min_shape[dim] > 1:
            constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim))
        if max_shape[dim] > 1:
            constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim])

Sometimes, we might need to set additional constraints, and TorchDynamo errors out if we don't specify them.
For example, in the case of BERT model compilation, there are two inputs and a constraint has to be set involving the sequence length size of these two inputs.

.. code-block:: python

    constraints.append(dynamic_dim(trace_inputs[0], 0) == dynamic_dim(trace_inputs[1], 0))

If you have to provide any custom constraints to your model, the overall workflow for model compilation using ``ir=dynamo`` would involve a few steps.

.. code-block:: python

    import unittest.mock

    import torch
    import torch_tensorrt
    # export and dynamic_dim are assumed to come from torch._export in this release
    from torch._export import dynamic_dim, export
    from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions

    # Assume the model has two inputs
    model = MyModel().eval().cuda()
    torch_input_1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()
    torch_input_2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()

    dynamic_inputs = [torch_tensorrt.Input(min_shape=[1, 14],
                                           opt_shape=[4, 14],
                                           max_shape=[8, 14],
                                           dtype=torch.int32),
                      torch_tensorrt.Input(min_shape=[1, 14],
                                           opt_shape=[4, 14],
                                           max_shape=[8, 14],
                                           dtype=torch.int32)]

    # Export the model with additional constraints
    constraints = []
    # The following constraints are automatically added by Torch-TensorRT in the
    # general case when you call torch_tensorrt.compile directly on MyModel()
    constraints.append(dynamic_dim(torch_input_1, 0) < 8)
    constraints.append(dynamic_dim(torch_input_2, 0) < 8)
    # This is an additional constraint as instructed by TorchDynamo
    constraints.append(dynamic_dim(torch_input_1, 0) == dynamic_dim(torch_input_2, 0))
    # experimental_decompositions and compile_spec are assumed to be defined by the user
    with unittest.mock.patch(
        "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
    ):
        graph_module = export(
            model, (torch_input_1, torch_input_2), constraints=constraints
        ).module()

    # Use the dynamo.compile API
    trt_mod = torch_tensorrt.dynamo.compile(graph_module, inputs=dynamic_inputs, **compile_spec)

Limitations
-----------

If there are operations in the graph that use the dynamic dimension of the input, PyTorch
introduces ``torch.ops.aten.sym_size.int`` ops in the graph. Currently, we cannot handle these operators and
the compilation results in undefined behavior. We plan to add support for these operators and implement
robust support for shape tensors in the next release. Here is an example of the limitation described above

.. code-block:: python

    import torch
    import torch_tensorrt

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.avgpool(x)
            out = torch.flatten(x, 1)
            return out

    model = MyModule().eval().cuda()
    # Compile with dynamic shapes
    inputs = torch_tensorrt.Input(min_shape=(1, 512, 1, 1),
                                  opt_shape=(4, 512, 1, 1),
                                  max_shape=(8, 512, 1, 1),
                                  dtype=torch.float32)
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

The traced graph of ``MyModule()`` looks as follows

.. code-block:: python

    Post export graph: graph():
        %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
        %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%arg0_1, [-1, -2], True), kwargs = {})
        %sym_size : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%arg0_1, 0), kwargs = {})
        %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mean, [%sym_size, 512]), kwargs = {})
        return (view,)

Here the ``%sym_size`` node captures the dynamic batch dimension and uses it in the ``aten.view`` layer. This requires shape tensor support,
which will be part of our next release.

Workaround (BERT static compilation example)
---------------------------------------------

In the case where you encounter the issues mentioned in the **Limitations** section,
you can compile the model (static mode) with the maximum input size that can be provided. In the case of smaller inputs,
we can pad them accordingly. This is only a workaround until we address the limitations.

.. code-block:: python

    import torch
    import torch_tensorrt
    from transformers import BertModel
    from transformers.utils.fx import symbolic_trace as transformers_trace

    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()

    # Input sequence length is 20.
    input1 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")
    input2 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")

    # compile_spec is assumed to hold any additional compilation settings
    model = transformers_trace(model, input_names=["input_ids", "attention_mask"]).eval().cuda()
    trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
    model_outputs = trt_mod(input1, input2)

    # If you have a sequence of length 14, pad 6 zero tokens and run inference
    # or recompile for a sequence length of 14.
    input1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    input2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
    model_outputs = trt_mod(input1, input2)

Dynamic shapes with ir=torch_compile
------------------------------------

``torch_tensorrt.compile(model, inputs, ir="torch_compile")`` returns a torch.compile boxed function with the backend
configured to TensorRT. In the case of ``ir=torch_compile``, users have to recompile for different input shapes.
In the future, we plan to explore the option of compiling with dynamic shapes in the first execution of the model.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs)
    # Compilation happens when you call the model
    trt_gm(inputs)

    # Recompilation happens with a modified batch size
    inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs_bs2)
    # Recompilation is triggered on the call
    trt_gm(inputs_bs2)

py/torch_tensorrt/_Input.py

Lines changed: 14 additions & 1 deletion
@@ -46,6 +46,7 @@ class _ShapeMode(Enum):
     low_tensor_domain_incl: float = 0.0
     high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
     torch_dtype: torch.dtype = torch.float32
+    torch_tensor: torch.Tensor = None

     def __init__(self, *args: Any, **kwargs: Any) -> None:
         """__init__ Method for torch_tensorrt.Input
@@ -171,6 +172,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

         self.tensor_domain = Input._parse_tensor_domain(domain)

+        if "torch_tensor" in kwargs:
+            self.torch_tensor = kwargs["torch_tensor"]
+        else:
+            if self.shape_mode == Input._ShapeMode.DYNAMIC:
+                self.torch_tensor = self.example_tensor("opt_shape")
+            else:
+                self.torch_tensor = self.example_tensor()
+
     def __str__(self) -> str:
         if self.shape_mode == Input._ShapeMode.STATIC:
             return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format(
@@ -220,6 +229,8 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
             return _enums.dtype.half
         elif dtype == torch.float:
             return _enums.dtype.float
+        elif dtype == torch.float64:
+            return _enums.dtype.double
         elif dtype == torch.bool:
             return _enums.dtype.bool
         else:
@@ -249,6 +260,8 @@ def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
             return torch.float
         elif dtype == _enums.dtype.bool:
             return torch.bool
+        elif dtype == _enums.dtype.double:
+            return torch.float64
         else:
             # Default torch_dtype used in FX path
             return torch.float32
@@ -354,7 +367,7 @@ def from_tensor(
             )
             else torch.channels_last
         )
-        return cls(shape=t.shape, dtype=t.dtype, format=frmt)
+        return cls(shape=t.shape, dtype=t.dtype, format=frmt, torch_tensor=t)

     @classmethod
     def from_tensors(
py/torch_tensorrt/_compile.py

Lines changed: 9 additions & 8 deletions
@@ -214,19 +214,20 @@ def compile(
         )
         return compiled_fx_module
     elif target_ir == _IRType.dynamo:
+        # Prepare torch and torchtrt inputs
         import collections.abc

-        from torch_tensorrt import Device
-        from torch_tensorrt.dynamo.utils import prepare_inputs, to_torch_device
+        from torch_tensorrt.dynamo.utils import prepare_inputs

-        if not isinstance(inputs, collections.abc.Sequence):
-            inputs = [inputs]
-        device = kwargs.get("device", Device._current_device())
-        torchtrt_inputs, torch_inputs = prepare_inputs(inputs, to_torch_device(device))
-        module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs)
+        if not isinstance(input_list, collections.abc.Sequence):
+            input_list = [input_list]
+
+        # Export the module
+        torchtrt_inputs = prepare_inputs(input_list)
+        module = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
         compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
             module,
-            inputs=input_list,
+            inputs=torchtrt_inputs,
             enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
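
In user-facing terms, the dynamo path now normalizes raw tensors and ``torch_tensorrt.Input`` objects via ``prepare_inputs``, exports the module with ``torch_tensorrt.dynamo.trace``, and then converts it with the dynamo compiler. A minimal sketch of the resulting call (``MyModel`` is a placeholder):

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()

    # Raw tensors and torch_tensorrt.Input objects can both appear in the inputs list;
    # prepare_inputs converts them to Input objects before tracing and conversion.
    trt_gm = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=[torch.randn(1, 3, 224, 224).cuda()],
    )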

py/torch_tensorrt/csrc/tensorrt_classes.cpp

Lines changed: 6 additions & 0 deletions
@@ -18,6 +18,8 @@ std::string to_str(DataType value) {
       return "Float";
     case DataType::kLong:
       return "Long";
+    case DataType::kDouble:
+      return "Double";
     default:
       return "Unknown data type";
   }
@@ -33,6 +35,8 @@ nvinfer1::DataType toTRTDataType(DataType value) {
       return nvinfer1::DataType::kINT32;
     case DataType::kLong:
       return nvinfer1::DataType::kINT32;
+    case DataType::kDouble:
+      return nvinfer1::DataType::kFLOAT;
     case DataType::kBool:
       return nvinfer1::DataType::kBOOL;
     case DataType::kFloat:
@@ -58,6 +62,8 @@ at::ScalarType toAtenDataType(DataType value) {
       return at::kBool;
     case DataType::kFloat:
       return at::kFloat;
+    case DataType::kDouble:
+      return at::kDouble;
     case DataType::kUnknown:
       return at::kFloat;
     default:
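
Taken together with the ``_Input.py`` change above, ``torch.float64`` inputs are now accepted on the Python side and lowered to FP32 for TensorRT (``kDouble`` maps to ``nvinfer1::DataType::kFLOAT``). A small sketch of the expected behavior:

    import torch
    import torch_tensorrt

    # torch.float64 now parses to the new double dtype instead of raising an error;
    # engines built from double-precision inputs still execute in single precision.
    inp = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float64)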

py/torch_tensorrt/csrc/tensorrt_classes.h

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ namespace pyapi {
     return static_cast<int64_t>(field_name); \
   }

-enum class DataType : int8_t { kLong, kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
+enum class DataType : int8_t { kLong, kDouble, kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
 std::string to_str(DataType value);
 nvinfer1::DataType toTRTDataType(DataType value);
 at::ScalarType toAtenDataType(DataType value);

0 commit comments
