.. _runtime:

Dynamic shapes with Torch-TensorRT
====================================

By default, you can run a PyTorch model with varied input shapes, and the output shapes are determined eagerly.
However, Torch-TensorRT is an AOT compiler which requires some prior information about the input shapes to compile and optimize the model.
In the case of dynamic input shapes, we must provide the ``(min_shape, opt_shape, max_shape)`` arguments so that the model can be optimized for
this range of input shapes. An example usage of static and dynamic shapes is as follows.

NOTE: The following code uses the dynamo IR. In case of the TorchScript IR, please swap out ``ir=dynamo`` with ``ir=ts``; the behavior is exactly the same.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    # Compile with static shapes
    inputs = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)
    # or compile with dynamic shapes
    inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                                  opt_shape=[4, 3, 224, 224],
                                  max_shape=[8, 3, 224, 224],
                                  dtype=torch.float32)
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

Under the hood
--------------

There are two phases of compilation when we use the ``torch_tensorrt.compile`` API with ``ir=dynamo`` (default).

- aten_tracer.trace (which uses torch.export to trace the graph with the given inputs)

In the tracing phase, we use torch.export along with the constraints. In the case of
dynamic shaped inputs, the range can be provided to the tracing via constraints. Please
refer to this `docstring <https://github.com/pytorch/pytorch/blob/5dcee01c2b89f6bedeef9dd043fd8d6728286582/torch/export/__init__.py#L372-L434>`_
for detailed information on how to set constraints. In short, we create new inputs for
torch.export tracing and provide constraints on the min and max values (provided by the user) that a particular dimension can take.
Please take a look at the ``aten_tracer.py`` file to understand how this works under the hood.

- dynamo.compile (which compiles a torch.fx.GraphModule object using TensorRT)

In the conversion to TensorRT, we use the user-provided dynamic shape inputs.
We perform shape analysis using dummy inputs (across the min, opt and max shapes) and store the
intermediate output shapes, which can be used in case the graph has a mix of PyTorch
and TensorRT submodules.
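
To make the shape analysis step concrete, the following is a minimal sketch of the idea and not the actual Torch-TensorRT implementation; the helper name and the single-input assumption are ours. It simply runs a module with dummy tensors at the min, opt and max shapes and records the resulting output shapes.

.. code-block:: python

    import torch

    def analyze_output_shapes(module, min_shape, opt_shape, max_shape):
        # Hypothetical helper: run the module once per profile shape with a
        # dummy tensor and record the output shape for each of min/opt/max.
        recorded = {}
        for name, shape in (("min", min_shape), ("opt", opt_shape), ("max", max_shape)):
            dummy = torch.rand(shape, dtype=torch.float32, device="cuda")
            with torch.no_grad():
                out = module(dummy)
            recorded[name] = tuple(out.shape)
        return recorded

    # e.g. for the example above:
    # analyze_output_shapes(model, [1, 3, 224, 224], [4, 3, 224, 224], [8, 3, 224, 224])
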
Custom Constraints
------------------

Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
Torch-TensorRT automatically sets the constraints during ``torch.export`` tracing as follows:

.. code-block:: python

    for dim in constraint_dims:
        if min_shape[dim] > 1:
            constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim))
        if max_shape[dim] > 1:
            constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim])

Sometimes, we might need to set additional constraints, and TorchDynamo errors out if we don't specify them.
For example, in the case of BERT model compilation, there are two inputs and a constraint has to be set involving the sequence length size of these two inputs.

.. code-block:: python

    constraints.append(dynamic_dim(trace_inputs[0], 0) == dynamic_dim(trace_inputs[1], 0))

If you have to provide any custom constraints to your model, the overall workflow for model compilation using ``ir=dynamo`` would involve a few steps.

.. code-block:: python

    import unittest.mock

    import torch
    import torch_tensorrt
    from torch.export import dynamic_dim, export
    from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions

    # Assume the model has two inputs
    model = MyModel()
    torch_input_1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()
    torch_input_2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()

    dynamic_inputs = [torch_tensorrt.Input(min_shape=[1, 14],
                                           opt_shape=[4, 14],
                                           max_shape=[8, 14],
                                           dtype=torch.int32),
                      torch_tensorrt.Input(min_shape=[1, 14],
                                           opt_shape=[4, 14],
                                           max_shape=[8, 14],
                                           dtype=torch.int32)]

    # Export the model with additional constraints
    constraints = []
    # The following constraints are automatically added by Torch-TensorRT in the
    # general case when you call torch_tensorrt.compile directly on MyModel()
    constraints.append(dynamic_dim(torch_input_1, 0) < 8)
    constraints.append(dynamic_dim(torch_input_2, 0) < 8)
    # This is an additional constraint as instructed by TorchDynamo
    constraints.append(dynamic_dim(torch_input_1, 0) == dynamic_dim(torch_input_2, 0))
    # experimental_decompositions (a boolean flag) and compile_spec (the remaining
    # compilation settings) are assumed to be defined elsewhere.
    with unittest.mock.patch(
        "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
    ):
        graph_module = export(
            model, (torch_input_1, torch_input_2), constraints=constraints
        ).module()

    # Use the dynamo.compile API
    trt_mod = torch_tensorrt.dynamo.compile(graph_module, inputs=dynamic_inputs, **compile_spec)

Limitations
-----------

If there are operations in the graph that use the dynamic dimension of the input, PyTorch
introduces ``torch.ops.aten.sym_size.int`` ops in the graph. Currently, we cannot handle these operators and
the compilation results in undefined behavior. We plan to add support for these operators and implement
robust support for shape tensors in the next release. Here is an example of the limitation described above.

.. code-block:: python

    import torch
    import torch_tensorrt

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.avgpool(x)
            out = torch.flatten(x, 1)
            return out

    model = MyModule().eval().cuda()
    # Compile with dynamic shapes
    inputs = torch_tensorrt.Input(min_shape=(1, 512, 1, 1),
                                  opt_shape=(4, 512, 1, 1),
                                  max_shape=(8, 512, 1, 1),
                                  dtype=torch.float32)
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

The traced graph of ``MyModule()`` looks as follows

.. code-block:: python

    Post export graph: graph():
        %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
        %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%arg0_1, [-1, -2], True), kwargs = {})
        %sym_size : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%arg0_1, 0), kwargs = {})
        %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mean, [%sym_size, 512]), kwargs = {})
        return (view,)

Here the ``%sym_size`` node captures the dynamic batch dimension and uses it in the ``aten.view`` layer. This requires shape tensor support,
which will be a part of our next release.

Workaround (BERT static compilation example)
--------------------------------------------

In the case where you encounter the issues mentioned in the **Limitations** section,
you can compile the model (static mode) with the maximum input size that can be provided. In the case of smaller inputs,
we can pad them accordingly. This is only a workaround until we address the limitations.

.. code-block:: python

    import torch
    import torch_tensorrt
    from transformers import BertModel
    from transformers.utils.fx import symbolic_trace as transformers_trace

    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()

    # Input sequence length is 20.
    input1 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")
    input2 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")

    # compile_spec (the remaining compilation settings) is assumed to be defined elsewhere.
    model = transformers_trace(model, input_names=["input_ids", "attention_mask"]).eval().cuda()
    trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
    model_outputs = model(input1, input2)

    # If you have a sequence of length 14, pad 6 zero tokens and run inference
    # or recompile for a sequence length of 14.
    input1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    input2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
    model_outputs = model(input1, input2)
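
To illustrate the padding path, the following is a rough sketch rather than an official Torch-TensorRT recipe: it assumes ``trt_mod`` is the engine compiled for sequence length 20 above and that a zero padding token is acceptable for your model (check your tokenizer's actual pad token id). A shorter input can be right-padded before calling the compiled module:

.. code-block:: python

    import torch
    import torch.nn.functional as F

    # Inputs with sequence length 14; the engine above was compiled for length 20.
    short_input1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    short_input2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")

    # Right-pad the last (sequence) dimension with 6 zeros to reach length 20.
    padded_input1 = F.pad(short_input1, (0, 6), value=0)
    padded_input2 = F.pad(short_input2, (0, 6), value=0)

    model_outputs = trt_mod(padded_input1, padded_input2)
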
Dynamic shapes with ir=torch_compile
------------------------------------

``torch_tensorrt.compile(model, inputs, ir="torch_compile")`` returns a torch.compile boxed function with the backend
configured to TensorRT. In the case of ``ir=torch_compile``, users have to recompile for different input shapes.
In the future, we plan to explore the option of compiling with dynamic shapes in the first execution of the model.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs)
    # Compilation happens when you call the model
    trt_gm(inputs)

    # Recompilation happens with a modified batch size
    inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs_bs2)