Skip to content

Commit 106bca7

Browse files
committed
Use len(kernel_shape) + 1 instead of unbatched_rank
1 parent 2ec7678 commit 106bca7

File tree

1 file changed

+4
-5
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+4
-5
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,8 @@ def _aten_avg_pool_onnx(
121121
pads: Sequence[int],
122122
ceil_mode: bool,
123123
count_include_pad: bool,
124-
unbatched_rank: int,
125124
) -> TFloat:
126-
self_rank_is_unbatched_rank = len(self.shape) == unbatched_rank
125+
self_rank_is_unbatched_rank = len(self.shape) == len(kernel_shape) + 1
127126
if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1
128127
self = op.Unsqueeze(self, [0])
129128

@@ -162,7 +161,7 @@ def aten_avg_pool1d(
162161
expand_size, kernel_size, stride, padding
163162
)
164163

165-
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad, 2)
164+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
166165

167166

168167
@torch_op("aten::avg_pool2d", trace_only=True)
@@ -199,7 +198,7 @@ def aten_avg_pool2d(
199198
# S is stride size, in this case S=4,
200199
# S may dup lot of times according to the image size
201200

202-
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad, 3)
201+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
203202

204203

205204
def aten_avg_pool2d_backward(
@@ -251,7 +250,7 @@ def aten_avg_pool3d(
251250
# S is stride size, in this case S=4,
252251
# S may dup lot of times according to the image size
253252

254-
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad, 4)
253+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
255254

256255

257256
def aten_avg_pool3d_backward(

0 commit comments

Comments
 (0)