Skip to content

Commit 45c2726

Browse files
committed
feat: Add toggle for fallback to Inductor
1 parent 7062730 commit 45c2726

File tree

4 files changed

+29
-3
lines changed

4 files changed

+29
-3
lines changed

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
USE_FAST_PARTITIONER = True
1616
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
1717
REQUIRE_FULL_COMPILATION = False
18+
FALLBACK_TO_INDUCTOR = True
1819

1920

2021
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch_tensorrt.dynamo._defaults import (
77
DEBUG,
88
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
9+
FALLBACK_TO_INDUCTOR,
910
MAX_AUX_STREAMS,
1011
MIN_BLOCK_SIZE,
1112
OPTIMIZATION_LEVEL,
@@ -42,6 +43,8 @@ class CompilationSettings:
4243
truncate_long_and_double (bool): Truncate int64/float64 TRT engine inputs or weights to int32/float32
4344
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
4445
or only a selected subset of them
46+
fallback_to_inductor (bool): Whether to fall back to Inductor on Torch-TRT compilation errors.
47+
Ignored when pass_through_build_failures is True, because the build error is re-raised instead.
4548
"""
4649

4750
precision: torch.dtype = PRECISION
@@ -59,3 +62,4 @@ class CompilationSettings:
5962
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
6063
device: Device = field(default_factory=default_device)
6164
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
65+
fallback_to_inductor: bool = FALLBACK_TO_INDUCTOR

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,17 @@ def _pretraced_backend(
8383
settings=settings,
8484
)
8585
return trt_compiled
86-
except AssertionError:
86+
except (AssertionError, RuntimeError):
8787
if not settings.pass_through_build_failures:
8888
logger.warning(
8989
"TRT conversion failed on the subgraph. See trace above. "
9090
+ "Returning GraphModule forward instead.",
9191
exc_info=True,
9292
)
93-
return gm.forward
93+
if settings.fallback_to_inductor:
94+
pass
95+
else:
96+
return gm
9497
else:
9598
logger.critical(
9699
"Halting compilation on build failure since "
@@ -100,3 +103,18 @@ def _pretraced_backend(
100103
+ "specify pass_through_build_failures=False."
101104
)
102105
raise
106+
107+
# If Inductor fallback is desired, attempt model compilation with inductor
108+
try:
109+
inductor_compiled = torch._inductor.compile(
110+
gm,
111+
sample_inputs,
112+
)
113+
return inductor_compiled
114+
except (AssertionError, RuntimeError):
115+
logger.warning(
116+
"Inductor compilation failed on the subgraph. See trace above. "
117+
+ "Returning GraphModule forward instead.",
118+
exc_info=True,
119+
)
120+
return gm

py/torch_tensorrt/dynamo/compile.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
DEBUG,
1616
DEVICE,
1717
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
18+
FALLBACK_TO_INDUCTOR,
1819
MAX_AUX_STREAMS,
1920
MIN_BLOCK_SIZE,
2021
OPTIMIZATION_LEVEL,
@@ -69,6 +70,7 @@ def compile(
6970
use_python_runtime: bool = USE_PYTHON_RUNTIME,
7071
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
7172
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
73+
fallback_to_inductor: bool = FALLBACK_TO_INDUCTOR,
7274
**kwargs: Any,
7375
) -> torch.fx.GraphModule:
7476
if debug:
@@ -84,7 +86,7 @@ def compile(
8486
"max_aux_streams, version_compatible, optimization_level, "
8587
"torch_executed_ops, pass_through_build_failures, "
8688
"use_fast_partitioner, enable_experimental_decompositions, "
87-
"require_full_compilation}"
89+
"require_full_compilation, fallback_to_inductor}"
8890
)
8991

9092
if not isinstance(inputs, collections.abc.Sequence):
@@ -130,6 +132,7 @@ def compile(
130132
"use_fast_partitioner": use_fast_partitioner,
131133
"enable_experimental_decompositions": enable_experimental_decompositions,
132134
"require_full_compilation": require_full_compilation,
135+
"fallback_to_inductor": fallback_to_inductor,
133136
}
134137

135138
settings = CompilationSettings(**compilation_options)

0 commit comments

Comments
 (0)