
Commit 9cced93

cherry pick 2740 to release2.4 branch. (#3033)
1 parent 9a7f272

2 files changed (+156, -3)

py/torch_tensorrt/dynamo/lowering/_decompositions.py (+43, -3)

@@ -4,7 +4,9 @@
 import torch
 from torch._decomp import register_decomposition
 from torch._ops import OpOverload
+from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
+from torch_tensorrt.dynamo.utils import to_torch_device

 from ._decomposition_groups import (
     ENABLED_TORCH_DECOMPOSITIONS,
@@ -166,7 +168,7 @@ def var_decomposition(
 @register_torch_trt_decomposition(
     torch.ops.aten.empty_permuted.default, registry=TORCH_TRT_DECOMPOSITIONS
 )
-def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
+def empty_permuted_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
     empty_size = args[0]
     empty_permute = args[1]
     perm = [0] * len(empty_size)
@@ -185,7 +187,7 @@ def slice_scatter_decomposition(
     start: Optional[int] = None,
     end: Optional[int] = None,
     step: Optional[int] = None,
-):
+) -> torch.Tensor:
     dim_size = input_tensor.shape[dim]
     start = get_positive_dim(start, input_tensor.shape[dim])
     if end is None:
@@ -230,12 +232,50 @@ def select_scatter_decomposition(
 @register_torch_trt_decomposition(
     torch.ops.aten.empty_strided.default, registry=TORCH_TRT_DECOMPOSITIONS
 )
-def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
+def empty_strided_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
     empty_size = args[0]
     empty_stride = args[1]
     return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)


+@register_torch_trt_decomposition(
+    torch.ops.aten.scatter_add.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scatter_add_decomposition(
+    input_tensor: torch.Tensor,
+    dim: int,
+    index: torch.Tensor,
+    src_tensor: torch.Tensor,
+) -> torch.Tensor:
+    scatter_add_tensor = input_tensor
+    src_shape = list(src_tensor.shape)
+    src_dim = src_shape[dim]
+    for i in range(0, src_dim):
+        to_scatter_tensor = torch.zeros_like(input_tensor)
+
+        # index and src slice
+        src_slice = torch.select(src_tensor, dim, i)
+        index_slice = torch.select(index, dim, i)
+
+        # unsqueeze src and index in dim
+        src_slice = torch.unsqueeze(src_slice, dim)
+        index_slice = torch.unsqueeze(index_slice, dim)
+
+        # moving tensor to default device
+        device = to_torch_device(default_device())
+        scatter_add_tensor = scatter_add_tensor.to(device)
+        to_scatter_tensor = to_scatter_tensor.to(device)
+        index_slice = index_slice.to(device)
+        src_slice = src_slice.to(device)
+
+        scatter_add_tensor = torch.add(
+            scatter_add_tensor,
+            torch.scatter(to_scatter_tensor, dim, index_slice, src_slice),
+        )
+
+    return scatter_add_tensor
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
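
The new scatter_add_decomposition rewrites aten.scatter_add as a loop over slices of src along dim: each slice is scattered (overwrite semantics) into a zero tensor and the partial results are accumulated with torch.add, so duplicate indices across slices still sum correctly. The standalone sketch below is not part of the commit and uses a made-up helper name (scatter_add_via_scatter); it mirrors that loop in plain PyTorch on CPU and checks it against the reference torch.scatter_add. The per-slice device moves in the real decomposition (to_torch_device(default_device())) are omitted here since they are not needed on CPU.

import torch

def scatter_add_via_scatter(input_tensor, dim, index, src_tensor):
    # Mirror of the decomposition loop: one src/index slice along `dim` per iteration.
    out = input_tensor.clone()
    for i in range(src_tensor.shape[dim]):
        # Select the i-th slice and re-insert `dim`, so each slice carries at most
        # one contribution per output position (no collisions within a slice).
        src_slice = torch.select(src_tensor, dim, i).unsqueeze(dim)
        index_slice = torch.select(index, dim, i).unsqueeze(dim)
        # Scatter the slice into zeros (overwrite), then accumulate by addition.
        out = out + torch.zeros_like(input_tensor).scatter(dim, index_slice, src_slice)
    return out

inp = torch.zeros(3, 5, dtype=torch.int32)
index = torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]])
src = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32)

reference = torch.scatter_add(inp, 0, index, src)
assert torch.equal(scatter_add_via_scatter(inp, 0, index, src), reference)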

tests/py/dynamo/lowering/test_decompositions.py (+113)

@@ -962,6 +962,119 @@ def forward(self, input):
             f"The optimized model results shape and torch model results shape should be equal in empty_stride",
         )

+    @parameterized.expand(
+        [
+            (
+                "scatter_add_zero_dim_indexOne_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
+                {torch.ops.aten.add.Tensor},
+            ),
+            (
+                "scatter_add_zero_dim_indexTwo_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
+                {torch.ops.aten.add.Tensor, torch.ops.aten.scatter.src},
+            ),
+            (
+                "scatter_add_one_dim_indexOne_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+            ),
+            (
+                "scatter_add_one_dim_indexTwo_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+            ),
+            (
+                "scatter_add_one_dim_indexThree_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1], [3, 2, 1, 2]]).cuda(),
+                torch.tensor(
+                    [[1, 2, 3, 1], [5, 6, 5, 5], [2, 4, 3, 2]], dtype=torch.int32
+                ).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+            ),
+        ]
+    )
+    def test_scatter_add(self, _, dim, index, src, expected_ops_param):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.scatter_add.default(input, dim, index, src)
+
+        # Operations expected to be included in the traced graph after decompositions
+        expected_ops = expected_ops_param
+        unexpected_ops = {torch.ops.aten.scatter_add.default}
+
+        input = torch.zeros(3, 5, dtype=torch.int32).cuda()
+        inputs = [input]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=2,
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "Scatter_add TRT outputs don't match with the original model.",
+        )

 if __name__ == "__main__":
     run_tests()
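
For a quick sanity check of the test data, the first parameterized case can be reproduced with eager PyTorch on CPU. This snippet is illustrative only and not part of the commit; the .cuda() calls from the test are dropped so it runs without a GPU.

import torch

# input = zeros(3, 5), dim = 0, index = [[0, 1, 2, 0]], src = [[1, 2, 3, 4]]
# scatter_add places src[0][j] into row index[0][j] of column j and accumulates.
out = torch.scatter_add(
    torch.zeros(3, 5, dtype=torch.int32),
    0,
    torch.tensor([[0, 1, 2, 0]]),
    torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
)
print(out)
# tensor([[1, 0, 0, 4, 0],
#         [0, 2, 0, 0, 0],
#         [0, 0, 3, 0, 0]], dtype=torch.int32)

The test itself compares this eager result against the TRT-compiled graph, so running it locally requires a CUDA device and a TensorRT-enabled install; an invocation along the lines of python -m pytest tests/py/dynamo/lowering/test_decompositions.py -k scatter_add should select just the new cases.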
