Skip to content

Commit b5dbc11

Browse files
authored
Merge pull request #2003 from pytorch/dynamo_compile_trt_module_next
feat: Add support for `TorchTensorRTModule` in Dynamo [1 / x]
2 parents 44e4ffa + 1756b12 commit b5dbc11

21 files changed

+362
-73
lines changed

examples/fx/fx2trt_example_next.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
99
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter
1010
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
11-
from torch_tensorrt import TRTModuleNext as TRTModule, Device
11+
from torch_tensorrt.dynamo._TorchTensorRTModule import (
12+
TorchTensorRTModule as TRTModule,
13+
Device,
14+
)
1215

1316
# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
1417
# model to TensorRT via FX with existing FX based tooling. The general lowering flow

py/torch_tensorrt/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def _find_lib(name, paths):
9191
from torch_tensorrt import logging
9292
from torch_tensorrt._Input import Input
9393
from torch_tensorrt._Device import Device
94-
from torch_tensorrt._TRTModuleNext import TRTModuleNext
9594

9695
from torch_tensorrt import fx
9796

py/torch_tensorrt/_TRTModuleNext.py renamed to py/torch_tensorrt/dynamo/_TorchTensorRTModule.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
2-
from operator import truediv
3-
from typing import Any, List, Sequence, Tuple
2+
from typing import Any, List, Tuple
43

54
import torch
65
from torch_tensorrt import _C
@@ -9,8 +8,8 @@
98
logger = logging.getLogger(__name__)
109

1110

12-
class TRTModuleNext(torch.nn.Module):
13-
"""TRTModuleNext is a PyTorch module which encompasses an arbitrary TensorRT Engine.
11+
class TorchTensorRTModule(torch.nn.Module):
12+
"""TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
1413
1514
This module is backed by the Torch-TensorRT runtime and is fully compatibile with both
1615
FX / Python deployments (just ``import torch_tensorrt`` as part of the application) as
@@ -20,7 +19,7 @@ class TRTModuleNext(torch.nn.Module):
2019
The forward function is simpily forward(*args: torch.Tensor) -> Tuple[torch.Tensor] where
2120
the internal implementation is ``return Tuple(torch.ops.tensorrt.execute_engine(list(inputs), self.engine))``
2221
23-
> Note: TRTModuleNext only supports engines built with explict batch
22+
> Note: TorchTensorRTModule only supports engines built with explict batch
2423
2524
Attributes:
2625
name (str): Name of module (for easier debugging)
@@ -37,7 +36,7 @@ def __init__(
3736
output_binding_names: List[str] = [],
3837
target_device: Device = Device._current_device(),
3938
):
40-
"""__init__ method for torch_tensorrt.TRTModuleNext
39+
"""__init__ method for torch_tensorrt.TorchTensorRTModule
4140
4241
Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
4342
a PyTorch ``torch.nn.Module`` around it.
@@ -70,10 +69,7 @@ def __init__(
7069
)
7170
7271
"""
73-
logger.warning(
74-
"TRTModuleNext should be considered experimental stability, APIs are subject to change. Note: TRTModuleNext only supports engines built with explict batch"
75-
)
76-
super(TRTModuleNext, self).__init__()
72+
super(TorchTensorRTModule, self).__init__()
7773

7874
if not isinstance(serialized_engine, bytearray):
7975
ValueError("Expected serialized engine as bytearray")
@@ -89,8 +85,8 @@ def __init__(
8985
self.name + "_engine" if self.name != "" else "tensorrt_engine",
9086
target_device._to_serialized_rt_device(),
9187
serialized_engine,
92-
TRTModuleNext._pack_binding_names(self.input_binding_names),
93-
TRTModuleNext._pack_binding_names(self.output_binding_names),
88+
TorchTensorRTModule._pack_binding_names(self.input_binding_names),
89+
TorchTensorRTModule._pack_binding_names(self.output_binding_names),
9490
]
9591
)
9692
else:
@@ -154,7 +150,7 @@ def is_non_tensor(i: Tuple[Any, bool]) -> bool:
154150

155151
non_tensors = [i[0] for i in filter(zip(inputs, types), is_non_tensor)]
156152
raise RuntimeError(
157-
f"TRTModuleNext expects a flattened list of tensors as input, found non tensors: {non_tensors}"
153+
f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
158154
)
159155

160156
outputs = torch.ops.tensorrt.execute_engine(list(inputs), self.engine)

py/torch_tensorrt/dynamo/backend/__init__.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch_tensorrt
55
from functools import partial
66

7-
from typing import Any, Sequence
7+
from typing import Any, Optional, Sequence
88
from torch_tensorrt import EngineCapability, Device
99
from torch_tensorrt.fx.utils import LowerPrecision
1010

@@ -16,6 +16,10 @@
1616
WORKSPACE_SIZE,
1717
MIN_BLOCK_SIZE,
1818
PASS_THROUGH_BUILD_FAILURES,
19+
MAX_AUX_STREAMS,
20+
VERSION_COMPATIBLE,
21+
OPTIMIZATION_LEVEL,
22+
USE_PYTHON_RUNTIME,
1923
)
2024

2125

@@ -45,6 +49,10 @@ def compile(
4549
torch_executed_ops=[],
4650
torch_executed_modules=[],
4751
pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
52+
max_aux_streams=MAX_AUX_STREAMS,
53+
version_compatible=VERSION_COMPATIBLE,
54+
optimization_level=OPTIMIZATION_LEVEL,
55+
use_python_runtime=USE_PYTHON_RUNTIME,
4856
**kwargs,
4957
):
5058
if debug:
@@ -91,6 +99,10 @@ def compile(
9199
min_block_size=min_block_size,
92100
torch_executed_ops=torch_executed_ops,
93101
pass_through_build_failures=pass_through_build_failures,
102+
max_aux_streams=max_aux_streams,
103+
version_compatible=version_compatible,
104+
optimization_level=optimization_level,
105+
use_python_runtime=use_python_runtime,
94106
**kwargs,
95107
)
96108

@@ -114,6 +126,10 @@ def create_backend(
114126
min_block_size: int = MIN_BLOCK_SIZE,
115127
torch_executed_ops: Sequence[str] = set(),
116128
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
129+
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
130+
version_compatible: bool = VERSION_COMPATIBLE,
131+
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
132+
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
117133
**kwargs,
118134
):
119135
"""Create torch.compile backend given specified arguments
@@ -125,6 +141,13 @@ def create_backend(
125141
min_block_size: Minimum number of operators per TRT-Engine Block
126142
torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
127143
pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
144+
max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
145+
version_compatible: Provide version forward-compatibility for engine plan files
146+
optimization_level: Builder optimization 0-5, higher levels imply longer build time,
147+
searching for more optimization options. TRT defaults to 3
148+
use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
149+
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
150+
argument as None
128151
Returns:
129152
Backend for torch.compile
130153
"""
@@ -136,4 +159,9 @@ def create_backend(
136159
min_block_size=min_block_size,
137160
torch_executed_ops=torch_executed_ops,
138161
pass_through_build_failures=pass_through_build_failures,
162+
max_aux_streams=max_aux_streams,
163+
version_compatible=version_compatible,
164+
optimization_level=optimization_level,
165+
use_python_runtime=use_python_runtime,
166+
**kwargs,
139167
)

py/torch_tensorrt/dynamo/backend/_defaults.py

+4
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,7 @@
66
WORKSPACE_SIZE = 0
77
MIN_BLOCK_SIZE = 5
88
PASS_THROUGH_BUILD_FAILURES = False
9+
MAX_AUX_STREAMS = None
10+
VERSION_COMPATIBLE = False
11+
OPTIMIZATION_LEVEL = None
12+
USE_PYTHON_RUNTIME = None
+10-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Sequence
2+
from typing import Optional, Sequence
33

44
from torch_tensorrt.fx.utils import LowerPrecision
55
from torch_tensorrt.dynamo.backend._defaults import (
@@ -8,14 +8,22 @@
88
WORKSPACE_SIZE,
99
MIN_BLOCK_SIZE,
1010
PASS_THROUGH_BUILD_FAILURES,
11+
MAX_AUX_STREAMS,
12+
VERSION_COMPATIBLE,
13+
OPTIMIZATION_LEVEL,
14+
USE_PYTHON_RUNTIME,
1115
)
1216

1317

14-
@dataclass(frozen=True)
18+
@dataclass
1519
class CompilationSettings:
1620
precision: LowerPrecision = PRECISION
1721
debug: bool = DEBUG
1822
workspace_size: int = WORKSPACE_SIZE
1923
min_block_size: int = MIN_BLOCK_SIZE
2024
torch_executed_ops: Sequence[str] = field(default_factory=set)
2125
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
26+
max_aux_streams: Optional[int] = MAX_AUX_STREAMS
27+
version_compatible: bool = VERSION_COMPATIBLE
28+
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
29+
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME

py/torch_tensorrt/dynamo/backend/backends.py

+1
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def _compile_module(
139139
submodule,
140140
submodule_inputs,
141141
settings=settings,
142+
name=name,
142143
)
143144

144145
trt_modules[name] = trt_mod

py/torch_tensorrt/dynamo/backend/conversion.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Sequence, Union
22
import torch
3+
import io
34
from torch_tensorrt.fx.trt_module import TRTModule
4-
from torch_tensorrt import TRTModuleNext
55
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
66
from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
77
InputTensorSpec,
@@ -15,12 +15,14 @@ def convert_module(
1515
module: torch.fx.GraphModule,
1616
inputs: Sequence[torch.Tensor],
1717
settings: CompilationSettings = CompilationSettings(),
18-
) -> Union[TRTModuleNext, TRTModule]:
18+
name: str = "",
19+
):
1920
"""Convert an FX module to a TRT module
2021
Args:
2122
module: FX GraphModule to convert
2223
inputs: Sequence of Tensors representing inputs to the module
2324
settings: Compilation settings
25+
name: TRT engine name
2426
Returns:
2527
TRTModule or TRTModuleNext
2628
"""
@@ -48,10 +50,27 @@ def convert_module(
4850
if settings.debug
4951
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
5052
),
53+
max_aux_streams=settings.max_aux_streams,
54+
version_compatible=settings.version_compatible,
55+
optimization_level=settings.optimization_level,
5156
)
5257

53-
return TRTModule(
54-
engine=interpreter_result.engine,
55-
input_names=interpreter_result.input_names,
56-
output_names=interpreter_result.output_names,
57-
)
58+
if settings.use_python_runtime:
59+
return TRTModule(
60+
engine=interpreter_result.engine,
61+
input_names=interpreter_result.input_names,
62+
output_names=interpreter_result.output_names,
63+
)
64+
65+
else:
66+
from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
67+
68+
with io.BytesIO() as engine_bytes:
69+
engine_bytes.write(interpreter_result.engine.serialize())
70+
engine_str = engine_bytes.getvalue()
71+
return TorchTensorRTModule(
72+
serialized_engine=engine_str,
73+
name=name,
74+
input_binding_names=interpreter_result.input_names,
75+
output_binding_names=interpreter_result.output_names,
76+
)

0 commit comments

Comments
 (0)