@@ -121,9 +121,8 @@ def _aten_avg_pool_onnx(
121121    pads : Sequence [int ],
122122    ceil_mode : bool ,
123123    count_include_pad : bool ,
124-     unbatched_rank : int ,
125124) ->  TFloat :
126-     self_rank_is_unbatched_rank  =  len (self .shape ) ==  unbatched_rank 
125+     self_rank_is_unbatched_rank  =  len (self .shape ) ==  len ( kernel_shape )  +   1 
127126    if  self_rank_is_unbatched_rank :  # C,H,W -> N,C,H,W and N=1 
128127        self  =  op .Unsqueeze (self , [0 ])
129128
@@ -162,7 +161,7 @@ def aten_avg_pool1d(
162161        expand_size , kernel_size , stride , padding 
163162    )
164163
165-     return  _aten_avg_pool_onnx (self , kernel_shape , strides , pads , ceil_mode , count_include_pad ,  2 )
164+     return  _aten_avg_pool_onnx (self , kernel_shape , strides , pads , ceil_mode , count_include_pad )
166165
167166
168167@torch_op ("aten::avg_pool2d" , trace_only = True ) 
@@ -199,7 +198,7 @@ def aten_avg_pool2d(
199198    # S is stride size, in this case S=4, 
200199    # S may dup lot of times according to the image size 
201200
202-     return  _aten_avg_pool_onnx (self , kernel_shape , strides , pads , ceil_mode , count_include_pad ,  3 )
201+     return  _aten_avg_pool_onnx (self , kernel_shape , strides , pads , ceil_mode , count_include_pad )
203202
204203
205204def  aten_avg_pool2d_backward (
@@ -251,7 +250,7 @@ def aten_avg_pool3d(
251250    # S is stride size, in this case S=4, 
252251    # S may dup lot of times according to the image size 
253252
254-     return  _aten_avg_pool_onnx (self , kernel_shape , strides , pads , ceil_mode , count_include_pad ,  4 )
253+     return  _aten_avg_pool_onnx (self , kernel_shape , strides , pads , ceil_mode , count_include_pad )
255254
256255
257256def  aten_avg_pool3d_backward (
0 commit comments