
Commit 746a97a

Authored by borisfom, with co-authors pre-commit-ci[bot], KumoLiu, yiheng-wang-nv, and binliunls
TRT support for MAISI (Project-MONAI#8153)
### Description

Added trt_compile() support for Lists and Tuples in arguments for forward() - needed for MAISI. Did not add support for grouping return results yet - MAISI worked with an explicit workaround unrolling the return results.

### Notes

To successfully export MAISI, either the latest Torch nightly is needed, or this patch needs to be applied to the 24.09-based container:

```
--- /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.bak	2024-10-09 01:38:04.920316673 +0000
+++ /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py	2024-10-09 01:38:25.228053951 +0000
@@ -148,7 +148,6 @@
         is_causal and symbolic_helper._is_none(attn_mask)
     ), "is_causal and attn_mask cannot be set at the same time"
 
-    scale = symbolic_helper._maybe_get_const(scale, "f")
     if symbolic_helper._is_none(scale):
         scale = _attention_scale(g, query)
```

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).

---------

Signed-off-by: Boris Fomitchev <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: binliunls <[email protected]>
1 parent 941e739 commit 746a97a
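
As a quick illustration of the new capability, here is a minimal sketch mirroring the `ListAdd` test added in this commit. The toy module, cache path, and shapes are illustrative, not part of the change itself, and a CUDA + TensorRT-enabled MONAI environment is assumed:

```python
# Sketch of the newly supported call pattern: forward() arguments may now
# include lists/tuples of tensors. Mirrors the ListAdd test added below.
import torch
from monai.networks import trt_compile


class ListAdd(torch.nn.Module):
    def forward(self, x: list[torch.Tensor], y: torch.Tensor):
        out = y.clone()
        for xi in x:
            out = out + xi
        return out


model = ListAdd().cuda()
trt_compile(model, "/tmp/list_add")  # engine is built lazily, on the first forward()
xs = [torch.randn(1, 16, device="cuda") for _ in range(2)]
result = model(xs, torch.randn(1, 16, device="cuda"))
```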

File tree

8 files changed (+215, -98 lines)

Dockerfile (+6)

```diff
@@ -41,6 +41,10 @@ RUN cp /tmp/requirements.txt /tmp/req.bak \
 COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./
 COPY tests ./tests
 COPY monai ./monai
+
+# TODO: remove this line and torch.patch for 24.11
+RUN patch -R -d /usr/local/lib/python3.10/dist-packages/torch/onnx/ < ./monai/torch.patch
+
 RUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \
   && rm -rf build __pycache__
@@ -57,4 +61,6 @@ RUN apt-get update \
 # append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations
 ENV PATH=${PATH}:/opt/tools
 ENV POLYGRAPHY_AUTOINSTALL_DEPS=1
+
+
 WORKDIR /opt/monai
```

monai/networks/nets/vista3d.py (-1)

```diff
@@ -641,7 +641,6 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor):
         # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension.
         masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d)
         masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1)
-
         return masks_embedding, class_embedding
 
```

monai/networks/trt_compiler.py (+158, -54)

Large diffs are not rendered by default.
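Since this diff is collapsed, here is a hedged sketch of the core idea, not the actual `trt_compiler.py` code: list/tuple arguments are expanded into a flat set of individually named tensors, since ONNX export and TRT engines consume a flat list of named inputs. The `flatten_args` helper and its naming scheme are hypothetical:

```python
import torch


def flatten_args(args: dict) -> dict:
    """Hypothetical helper: expand list/tuple values into flat named tensors."""
    flat = {}
    for name, value in args.items():
        if isinstance(value, (list, tuple)):
            for i, t in enumerate(value):
                flat[f"{name}_{i}"] = t  # e.g. x -> x_0, x_1 (illustrative naming)
        else:
            flat[name] = value
    return flat


args = {"x": [torch.randn(1, 16), torch.randn(1, 16)], "y": torch.randn(1, 16)}
print(list(flatten_args(args)))  # ['x_0', 'x_1', 'y']
```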

monai/networks/utils.py (+8, -3)

```diff
@@ -632,7 +632,6 @@ def convert_to_onnx(
     use_trace: bool = True,
     do_constant_folding: bool = True,
     constant_size_threshold: int = 16 * 1024 * 1024 * 1024,
-    dynamo=False,
     **kwargs,
 ):
     """
@@ -673,6 +672,9 @@ def convert_to_onnx(
         # let torch.onnx.export to trace the model.
         mode_to_export = model
         torch_versioned_kwargs = kwargs
+        if "dynamo" in kwargs and kwargs["dynamo"] and verify:
+            torch_versioned_kwargs["verify"] = verify
+            verify = False
     else:
         if not pytorch_after(1, 10):
             if "example_outputs" not in kwargs:
@@ -695,13 +697,13 @@ def convert_to_onnx(
         f = temp_file.name
     else:
         f = filename
-
+    print(f"torch_versioned_kwargs={torch_versioned_kwargs}")
     torch.onnx.export(
         mode_to_export,
         onnx_inputs,
         f=f,
         input_names=input_names,
-        output_names=output_names,
+        output_names=output_names or None,
         dynamic_axes=dynamic_axes,
         opset_version=opset_version,
         do_constant_folding=do_constant_folding,
@@ -715,6 +717,9 @@ def convert_to_onnx(
         fold_constants(onnx_model, size_threshold=constant_size_threshold)
 
     if verify:
+        if isinstance(inputs, dict):
+            inputs = list(inputs.values())
+
         if device is None:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
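
A hedged usage sketch of `convert_to_onnx` after these changes; the toy model and parameter values are illustrative. With `dynamo=True` in the kwargs and `verify=True`, verification is now delegated to `torch.onnx.export` via `torch_versioned_kwargs`, and dict-valued `inputs` are unwrapped to a list before local verification:

```python
# Sketch (assuming the convert_to_onnx API shown above): dynamo is no longer
# a named parameter and is forwarded to torch.onnx.export via **kwargs.
import torch
from monai.networks.utils import convert_to_onnx

model = torch.nn.Linear(16, 4).eval()
onnx_model = convert_to_onnx(
    model,
    inputs=[torch.randn(1, 16)],
    input_names=["x"],
    output_names=["y"],
    opset_version=18,
    dynamo=False,  # picked up from **kwargs into torch_versioned_kwargs
)
```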

monai/torch.patch (+9)

```diff
@@ -0,0 +1,9 @@
+--- /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py	2024-10-31 06:09:21.139938791 +0000
++++ /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.bak	2024-10-31 06:01:50.207462739 +0000
+@@ -150,6 +150,7 @@
+     ), "is_causal and attn_mask cannot be set at the same time"
+     assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+ 
++    scale = symbolic_helper._maybe_get_const(scale, "f")
+     if symbolic_helper._is_none(scale):
+         scale = _attention_scale(g, query)
```
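
For context, a hedged repro sketch of the export path this patch concerns, assuming a Torch build where `scaled_dot_product_attention` with an explicit float `scale` is exported through `symbolic_opset14`; the module and shapes are illustrative:

```python
import torch
import torch.nn.functional as F


class SDPA(torch.nn.Module):
    def forward(self, q, k, v):
        # Explicit float scale: the exporter must fold it to a constant,
        # which is what the _maybe_get_const line above handles.
        return F.scaled_dot_product_attention(q, k, v, scale=0.125)


q, k, v = (torch.randn(1, 4, 16, 8) for _ in range(3))
torch.onnx.export(SDPA(), (q, k, v), "sdpa.onnx", opset_version=14)
```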

monai/utils/module.py (+3, -3)

```diff
@@ -649,7 +649,7 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s
         current_ver_string: if None, the current system GPU CUDA compute capability will be used.
 
     Returns:
-        True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
+        True if the current system GPU CUDA compute capability is greater than the specified version.
     """
     if current_ver_string is None:
         cuda_available = torch.cuda.is_available()
@@ -667,11 +667,11 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s
 
     ver, has_ver = optional_import("packaging.version", name="parse")
     if has_ver:
-        return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}")  # type: ignore
+        return ver(".".join((f"{major}", f"{minor}"))) < ver(f"{current_ver_string}")  # type: ignore
     parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
     while len(parts) < 2:
         parts += ["0"]
     c_major, c_minor = parts[:2]
     c_mn = int(c_major), int(c_minor)
     mn = int(major), int(minor)
-    return c_mn > mn
+    return c_mn >= mn
```
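
For clarity, a small sketch of the comparison semantics after this change, on the `packaging.version` path taken when that package is importable (values are illustrative):

```python
from packaging.version import parse

# After the change, "after" means strictly greater on this path: an equal
# capability no longer satisfies the check (see the test change below).
print(parse("6.1") < parse("6.1"))  # False - equal versions
print(parse("6.1") < parse("8.6"))  # True - current capability is newer
```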

tests/test_trt_compile.py (+30, -36)

```diff
@@ -19,24 +19,32 @@
 
 from monai.handlers import TrtHandler
 from monai.networks import trt_compile
-from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132
+from monai.networks.nets import cell_sam_wrapper, vista3d132
 from monai.utils import min_version, optional_import
-from tests.utils import (
-    SkipIfAtLeastPyTorchVersion,
-    SkipIfBeforeComputeCapabilityVersion,
-    skip_if_no_cuda,
-    skip_if_quick,
-    skip_if_windows,
-)
+from tests.utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows
 
 trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version)
+torch_tensorrt, torch_trt_imported = optional_import("torch_tensorrt")
 polygraphy, polygraphy_imported = optional_import("polygraphy")
 build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")
 
 TEST_CASE_1 = ["fp32"]
 TEST_CASE_2 = ["fp16"]
 
 
+class ListAdd(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = 0.1):
+        y1 = y.clone()
+        x1 = x.copy()
+        z1 = z + y
+        for xi in x:
+            y1 = y1 + xi + bs
+        return x1, [y1, z1], y1 + z1
+
+
 @skip_if_windows
 @skip_if_no_cuda
 @skip_if_quick
@@ -53,7 +61,7 @@ def tearDown(self):
         if current_device != self.gpu_device:
             torch.cuda.set_device(self.gpu_device)
 
-    @SkipIfAtLeastPyTorchVersion((2, 4, 1))
+    @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required")
     def test_handler(self):
         from ignite.engine import Engine
 
@@ -74,29 +82,19 @@ def test_handler(self):
         net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda"))
         self.assertIsNotNone(net1._trt_compiler.engine)
 
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
-    def test_unet_value(self, precision):
-        model = UNet(
-            spatial_dims=3,
-            in_channels=1,
-            out_channels=2,
-            channels=(2, 2, 4, 8, 4),
-            strides=(2, 2, 2, 2),
-            num_res_units=2,
-            norm="batch",
-        ).cuda()
+    def test_lists(self):
+        model = ListAdd().cuda()
+
         with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
-            model.eval()
-            input_example = torch.randn(2, 1, 96, 96, 96).cuda()
-            output_example = model(input_example)
-            args: dict = {"builder_optimization_level": 1}
-            trt_compile(
-                model,
-                f"{tmpdir}/test_unet_trt_compile",
-                args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]},
-            )
+            args = {"output_lists": [[-1], [2], []], "export_args": {"dynamo": False, "verbose": True}}
+            x = torch.randn(1, 16).to("cuda")
+            y = torch.randn(1, 16).to("cuda")
+            z = torch.randn(1, 16).to("cuda")
+            input_example = ([x, y, z], y.clone(), z.clone())
+            output_example = model(*input_example)
+            trt_compile(model, f"{tmpdir}/test_lists", args=args)
             self.assertIsNone(model._trt_compiler.engine)
-            trt_output = model(input_example)
+            trt_output = model(*input_example)
             # Check that lazy TRT build succeeded
             self.assertIsNotNone(model._trt_compiler.engine)
             torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
@@ -109,11 +107,7 @@ def test_cell_sam_wrapper_value(self, precision):
             model.eval()
             input_example = torch.randn(1, 3, 128, 128).to("cuda")
             output_example = model(input_example)
-            trt_compile(
-                model,
-                f"{tmpdir}/test_cell_sam_wrapper_trt_compile",
-                args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
-            )
+            trt_compile(model, f"{tmpdir}/test_cell_sam_wrapper_trt_compile", args={"precision": precision})
             self.assertIsNone(model._trt_compiler.engine)
             trt_output = model(input_example)
             # Check that lazy TRT build succeeded
@@ -130,7 +124,7 @@ def test_vista3d(self, precision):
             model = trt_compile(
                 model,
                 f"{tmpdir}/test_vista3d_trt_compile",
-                args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
+                args={"precision": precision, "dynamic_batchsize": [1, 2, 4]},
                 submodule=["image_encoder.encoder", "class_head"],
             )
             self.assertIsNotNone(model.image_encoder.encoder._trt_compiler)
```

tests/test_version_after.py (+1, -1)

```diff
@@ -38,7 +38,7 @@
 
 TEST_CASES_SM = [
     # (major, minor, sm, expected)
-    (6, 1, "6.1", True),
+    (6, 1, "6.1", False),
     (6, 1, "6.0", False),
     (6, 0, "8.6", True),
     (7, 0, "8", True),
```
