Skip to content

Conversation

@wodesuck
Copy link
Contributor

Onnx's AveragePool require input shape as N,C,H,W, but torch accept both N,C,H,W and C,H,W. Unsqueeze if input is unbatched, just like what max_pool does.

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Oct 20, 2025
@justinchuby
Copy link
Collaborator

@codecov
Copy link

codecov bot commented Oct 20, 2025

Codecov Report

❌ Patch coverage is 63.63636% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.45%. Comparing base (8a94ad6) to head (f5ec077).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
onnxscript/function_libs/torch_lib/ops/nn.py 63.63% 2 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2646      +/-   ##
==========================================
- Coverage   70.46%   70.45%   -0.01%     
==========================================
  Files         224      224              
  Lines       26572    26577       +5     
  Branches     2637     2639       +2     
==========================================
+ Hits        18723    18724       +1     
- Misses       6928     6930       +2     
- Partials      921      923       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@wodesuck
Copy link
Contributor Author

@justinchuby Test added.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds support for unbatched input tensors in average pooling operations to match PyTorch's behavior. While ONNX's AveragePool requires NCHW format, PyTorch accepts both batched (NCHW) and unbatched (CHW) inputs. The changes handle unbatched inputs by automatically unsqueezing/squeezing dimensions, similar to the existing max_pool implementation.

Key Changes:

  • Introduced a helper function _aten_avg_pool_onnx that handles both batched and unbatched inputs
  • Refactored avg_pool1d, avg_pool2d, and avg_pool3d to use the new helper function
  • Added comprehensive tests covering all pooling dimensions with both batched and unbatched inputs

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
onnxscript/function_libs/torch_lib/ops/nn.py Refactored avg_pool operations to support unbatched inputs via new helper function
tests/function_libs/torch_lib/e2e_ops_tests.py Added test cases for avg_pool operations with various input dimensions

Copy link
Contributor

@titaiwangms titaiwangms left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. Minor comments and please check this page for lint: https://github.com/microsoft/onnxscript#coding-style

@titaiwangms
Copy link
Contributor

There is still something wrong with lint. Would you check?

@wodesuck
Copy link
Contributor Author

@titaiwangms Pylint says "torch.nn.functional.avg_pool1d is not callable", that's not true. I have run lintrunner locally without wrong, don't known why it still blame.

@titaiwangms
Copy link
Contributor

@titaiwangms Pylint says "torch.nn.functional.avg_pool1d is not callable", that's not true. I have run lintrunner locally without wrong, don't known why it still blame.

You can go ahead and disable it: To disable, use # pylint: disable=not-callable

@justinchuby justinchuby added this to the 0.5.5 milestone Oct 25, 2025
Copy link
Contributor

@titaiwangms titaiwangms left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@justinchuby justinchuby merged commit 04a9da4 into microsoft:main Oct 27, 2025
30 of 32 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module: torchlib Related to the torch/aten function lib in development

Projects

Development

Successfully merging this pull request may close these issues.

3 participants