Skip to content

Commit f7be3ae

Browse files
committed
Add conv2d e2e test from convnext model
1 parent 1b8d7e0 commit f7be3ae

File tree

1 file changed

+30
-0
lines changed
  • projects/pt1/python/torch_mlir_e2e_test/test_suite

1 file changed

+30
-0
lines changed

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

+30
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,36 @@ def forward(self, inputVec, weight):
255255
def Convolution2DStaticModule_basic(module, tu: TestUtils):
256256
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
257257

258+
class Convolution2DNextStaticModule(torch.nn.Module):
259+
def __init__(self):
260+
super().__init__()
261+
262+
@export
263+
@annotate_args(
264+
[
265+
None,
266+
([1,80,72,72], torch.float32, True),
267+
([80,1,7,7], torch.float32, True),
268+
([80], torch.float32, True),
269+
]
270+
)
271+
def forward(self, inputVec, weight, bias):
272+
return torch.ops.aten.convolution(
273+
inputVec,
274+
weight,
275+
bias=bias,
276+
stride=[1, 1],
277+
padding=[3, 3],
278+
dilation=[1, 1],
279+
transposed=False,
280+
output_padding=[0, 0],
281+
groups=80,
282+
)
283+
284+
285+
@register_test_case(module_factory=lambda: Convolution2DNextStaticModule())
286+
def Convolution2DNextStaticModule_basic(module, tu: TestUtils):
287+
module.forward(tu.rand(1,80,72,72), tu.rand(80,1,7,7), tu.rand(80))
258288

259289
class Convolution2DStridedModule(torch.nn.Module):
260290
def __init__(self):

0 commit comments

Comments
 (0)