Skip to content

Commit 7b825f5

Browse files
authored
chore: dynamic shape support for rsqrt/erf ops (#2929)
1 parent f3035be commit 7b825f5

File tree

4 files changed

+111
-6
lines changed

4 files changed

+111
-6
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ def aten_ops_matmul(
560560
)
561561

562562

563-
@dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default)
563+
@dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default, supports_dynamic_shapes=True)
564564
def aten_ops_rsqrt(
565565
ctx: ConversionContext,
566566
target: Target,
@@ -634,7 +634,7 @@ def aten_ops_squeeze(
634634
return impl.squeeze.squeeze(ctx, target, SourceIR.ATEN, name, args[0], args[1])
635635

636636

637-
@dynamo_tensorrt_converter(torch.ops.aten.erf.default)
637+
@dynamo_tensorrt_converter(torch.ops.aten.erf.default, supports_dynamic_shapes=True)
638638
def aten_ops_erf(
639639
ctx: ConversionContext,
640640
target: Target,

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ def rsqrt(
121121
name: str,
122122
input: TRTTensor,
123123
) -> TRTTensor:
124+
if (isinstance(input, TRTTensor)) and (
125+
input.dtype == trt.int8 or input.dtype == trt.int32
126+
):
127+
input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_cast")
124128
sqrt_trt_output = convert_unary(
125129
ctx,
126130
target,

tests/py/dynamo/conversion/test_erf_aten.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,54 @@ def forward(self, input):
4141
inputs,
4242
)
4343

44+
@parameterized.expand(
45+
[
46+
(
47+
"2d_dim_dtype_half",
48+
(1, 1),
49+
(2, 2),
50+
(4, 4),
51+
torch.half,
52+
torch.half,
53+
),
54+
(
55+
"3d_dim_dtype_float",
56+
(1, 1, 1),
57+
(1, 2, 3),
58+
(3, 3, 3),
59+
torch.float,
60+
torch.float,
61+
),
62+
(
63+
"3d_dim_dtype_int32",
64+
(1, 1, 1),
65+
(2, 2, 4),
66+
(2, 3, 5),
67+
torch.int32,
68+
torch.float,
69+
),
70+
]
71+
)
72+
def test_dynamic_shape_erf(
73+
self, _, min_shape, opt_shape, max_shape, type, output_type
74+
):
75+
class erf(nn.Module):
76+
def forward(self, input):
77+
return torch.ops.aten.erf.default(input)
78+
79+
input_specs = [
80+
Input(
81+
min_shape=min_shape,
82+
opt_shape=opt_shape,
83+
max_shape=max_shape,
84+
dtype=type,
85+
),
86+
]
87+
88+
self.run_test_with_dynamic_shape(
89+
erf(), input_specs, output_dtypes=[output_type]
90+
)
91+
4492

4593
if __name__ == "__main__":
4694
run_tests()

tests/py/dynamo/conversion/test_rsqrt_aten.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,74 @@
1010
class TestRSqrtConverter(DispatchTestCase):
1111
@parameterized.expand(
1212
[
13-
("2d_dim_alpha", (2, 1), 2),
14-
("3d_dim_alpha", (2, 1, 2), 2),
13+
("2d_dim_float", (2, 1), torch.float),
14+
("3d_dim_float", (2, 1, 2), torch.float),
15+
("3d_dim_int32", (2, 1, 2), torch.int32),
1516
]
1617
)
17-
def test_rsqrt(self, _, x, alpha):
18+
def test_rsqrt(self, _, x, type):
1819
class rsqrt(nn.Module):
1920
def forward(self, input):
2021
return torch.ops.aten.rsqrt.default(input)
2122

22-
inputs = [torch.randn(x) + 1]
23+
if type == torch.int32:
24+
inputs = [torch.randint(255, x, dtype=torch.int32)]
25+
else:
26+
inputs = [torch.randn(x) + 1]
27+
2328
self.run_test(
2429
rsqrt(),
2530
inputs,
2631
)
2732

33+
@parameterized.expand(
34+
[
35+
(
36+
"2d_dim_dtype_half",
37+
(1, 1),
38+
(2, 2),
39+
(4, 4),
40+
torch.half,
41+
torch.half,
42+
),
43+
(
44+
"3d_dim_dtype_float",
45+
(1, 1, 1),
46+
(1, 2, 3),
47+
(3, 3, 3),
48+
torch.float,
49+
torch.float,
50+
),
51+
(
52+
"3d_dim_dtype_int32",
53+
(1, 1, 1),
54+
(2, 2, 4),
55+
(2, 3, 5),
56+
torch.int32,
57+
torch.float,
58+
),
59+
]
60+
)
61+
def test_dynamic_shape_rsqrt(
62+
self, _, min_shape, opt_shape, max_shape, type, output_type
63+
):
64+
class rsqrt(nn.Module):
65+
def forward(self, input):
66+
return torch.ops.aten.rsqrt.default(input)
67+
68+
input_specs = [
69+
Input(
70+
min_shape=min_shape,
71+
opt_shape=opt_shape,
72+
max_shape=max_shape,
73+
dtype=type,
74+
),
75+
]
76+
77+
self.run_test_with_dynamic_shape(
78+
rsqrt(), input_specs, output_dtypes=[output_type]
79+
)
80+
2881

2982
if __name__ == "__main__":
3083
run_tests()

0 commit comments

Comments
 (0)