@@ -255,6 +255,36 @@ def forward(self, inputVec, weight):
255
255
def Convolution2DStaticModule_basic (module , tu : TestUtils ):
256
256
module .forward (tu .rand (3 , 3 , 10 , 10 ), tu .rand (3 , 3 , 2 , 2 ))
257
257
258
+ class Convolution2DNextStaticModule (torch .nn .Module ):
259
+ def __init__ (self ):
260
+ super ().__init__ ()
261
+
262
+ @export
263
+ @annotate_args (
264
+ [
265
+ None ,
266
+ ([1 ,80 ,72 ,72 ], torch .float32 , True ),
267
+ ([80 ,1 ,7 ,7 ], torch .float32 , True ),
268
+ ([80 ], torch .float32 , True ),
269
+ ]
270
+ )
271
+ def forward (self , inputVec , weight , bias ):
272
+ return torch .ops .aten .convolution (
273
+ inputVec ,
274
+ weight ,
275
+ bias = bias ,
276
+ stride = [1 , 1 ],
277
+ padding = [3 , 3 ],
278
+ dilation = [1 , 1 ],
279
+ transposed = False ,
280
+ output_padding = [0 , 0 ],
281
+ groups = 80 ,
282
+ )
283
+
284
+
285
+ @register_test_case (module_factory = lambda : Convolution2DNextStaticModule ())
286
+ def Convolution2DNextStaticModule_basic (module , tu : TestUtils ):
287
+ module .forward (tu .rand (1 ,80 ,72 ,72 ), tu .rand (80 ,1 ,7 ,7 ), tu .rand (80 ))
258
288
259
289
class Convolution2DStridedModule (torch .nn .Module ):
260
290
def __init__ (self ):
0 commit comments