Commit aca5d5c

fix: remove legacy conv converter (#3343)
1 parent: eff9b26

File tree: 4 files changed (+81 −11 lines)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+2 −7)

@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
+
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -2446,15 +2447,8 @@ def aten_ops_le(
     )


-def conv_param_validator(
-    conv_node: Node, settings: Optional[CompilationSettings] = None
-) -> bool:
-    return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])
-
-
 @dynamo_tensorrt_converter(
     torch.ops.aten.convolution.default,
-    capability_validator=conv_param_validator,
     supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
@@ -2500,6 +2494,7 @@ def aten_ops_convolution(
         stride=args[3],
         padding=args[4],
         dilation=args[5],
+        output_padding=args[7],
         groups=args[8],
     )
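Context for the removal: in the aten.convolution.default schema the positional arguments are (input, weight, bias, stride, padding, dilation, transposed, output_padding, groups), so args[7] is output_padding. The deleted conv_param_validator rejected any node whose output_padding was non-zero, forcing those transposed convolutions to fall back to PyTorch; the converter now forwards the value to the deconvolution impl instead. A minimal sketch of what the old gate did (the args tuple here is made up for illustration):

# Positional layout of aten.convolution.default:
# convolution(input, weight, bias, stride, padding, dilation,
#             transposed, output_padding, groups)
args = ("input", "weight", "bias", [2], [1], [1], True, [1], 1)

# The removed validator accepted only all-zero output_padding:
legacy_ok = args[7] in ([0], [0, 0], [0, 0, 0])
print(legacy_ok)  # False -> before this commit the node fell back to PyTorch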

py/torch_tensorrt/dynamo/conversion/impl/deconv.py (+23)

@@ -6,6 +6,7 @@
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
+
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
@@ -105,6 +106,9 @@ def deconvNd(
     padding = (padding,) if isinstance(padding, int) else padding
     stride = (stride,) if isinstance(stride, int) else stride
     dilation = (dilation,) if isinstance(dilation, int) else dilation
+    output_padding = (
+        (output_padding,) if isinstance(output_padding, int) else output_padding
+    )

     # Expand parameters manually for Conv1D computations
     if is_deconv1d:
@@ -113,6 +117,11 @@ def deconvNd(
         dilation = (
             extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
         )
+        output_padding = (
+            (tuple(output_padding) + (0,))
+            if output_padding is not None
+            else output_padding
+        )

     set_layer_name(deconv_layer, target, name, source_ir)
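These two hunks normalize output_padding the same way the surrounding code normalizes padding, stride, and dilation: an int becomes a 1-tuple, and in the deconv1d case (which is lowered to a 2D deconvolution) a trailing zero is appended for the extra spatial dim. A standalone sketch of both steps, with an assumed input value:

output_padding = 1

# Step 1: int -> 1-tuple, matching the padding/stride/dilation handling.
output_padding = (output_padding,) if isinstance(output_padding, int) else output_padding
print(output_padding)  # (1,)

# Step 2: deconv1d runs as a 2D deconvolution, so pad the extra dim with 0.
output_padding = tuple(output_padding) + (0,)
print(output_padding)  # (1, 0)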

@@ -126,6 +135,20 @@ def deconvNd(
     if groups is not None:
         deconv_layer.num_groups = groups

+    ndims = len(padding)
+    pre_padding_values = []
+    post_padding_values = []
+
+    for dim in range(ndims):
+        pre_padding = padding[dim]
+        post_padding = padding[dim] - output_padding[dim]
+
+        pre_padding_values.append(pre_padding)
+        post_padding_values.append(post_padding)
+
+    deconv_layer.pre_padding = tuple(pre_padding_values)
+    deconv_layer.post_padding = tuple(post_padding_values)
+
     # Handle quantization cases
     if scale is not None and zero_point is not None:
         # Assume the dtype of activation is torch.quint8
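The asymmetric split encodes PyTorch's convention that output_padding only extends the trailing edge of each spatial dim: TensorRT trims a deconvolution's output by pre_padding at the start and post_padding at the end, so setting post_padding = padding - output_padding trims output_padding fewer elements at the end. A quick sanity check of the size arithmetic against PyTorch (a sketch, not part of the commit):

import torch

# Expected ConvTranspose1d output length per the PyTorch docs:
# (n - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
def deconv_out_len(n, kernel, stride, padding, output_padding, dilation=1):
    return (n - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

deconv = torch.nn.ConvTranspose1d(
    1, 1, kernel_size=3, stride=2, padding=1, output_padding=1
)
x = torch.randn(1, 1, 8)
assert deconv(x).shape[-1] == deconv_out_len(8, 3, 2, 1, 1)  # 16
# With padding=(1,) and output_padding=(1,), the code above would set
# pre_padding=(1,) and post_padding=(0,): nothing is trimmed at the end.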

py/torch_tensorrt/fx/converters/aten_ops_converters.py (+2 −1)

@@ -10,9 +10,10 @@
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
-import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import Argument, Target
+
+import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.converters.impl import activation, convolution


tests/py/dynamo/conversion/test_deconvolution_aten.py (+54 −3)

@@ -1,6 +1,7 @@
 import torch
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
+
 from torch_tensorrt import Input

 from .harness import DispatchTestCase
@@ -15,6 +16,21 @@ class TestDeconvolutionConverter(DispatchTestCase):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_3", 3, stride=2, padding=3, output_padding=1),
+            param("output_padding_4", 3, stride=3, padding=2, output_padding=1),
+            param("output_padding_5", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_6", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv1d(
@@ -26,6 +42,7 @@ def test_deconv1d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -36,9 +53,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )

             def forward(self, x):
@@ -101,6 +119,22 @@ def forward(self, x):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_3", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_4", 3, stride=2, padding=3, output_padding=1),
+            param("output_padding_5", 3, stride=3, padding=2, output_padding=1),
+            param("output_padding_6", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_7", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv2d(
@@ -112,6 +146,7 @@ def test_deconv2d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
    ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -122,9 +157,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )

             def forward(self, x):
@@ -172,6 +208,19 @@ def forward(self, x):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_3", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_4", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv3d(
@@ -183,6 +232,7 @@ def test_deconv3d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -193,9 +243,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )

             def forward(self, x):
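Each parameterized case builds the corresponding ConvTranspose module and the DispatchTestCase harness compares the compiled TensorRT output against PyTorch. For reference, the output shape that a case like output_padding_1 exercises in 2D (a standalone sketch; channel counts and input size are made up):

import torch

m = torch.nn.ConvTranspose2d(
    3, 3, kernel_size=3, stride=2, padding=1, output_padding=1
)
x = torch.randn(1, 3, 8, 8)
# Per spatial dim: (8 - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 = 16
print(m(x).shape)  # torch.Size([1, 3, 16, 16])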
