Skip to content

Fix padding_idx=None handling in aten_embedding_bag_padding_idx #2385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3089,16 +3089,19 @@
sparse: bool = False,
per_sample_weights: Optional[TFloat] = None,
include_last_offset: bool = False,
padding_idx: int = -1,
padding_idx: Optional[int] = -1,
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
"""embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)

We add default values for the attributes to accommodate _embedding_bag as well:
_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
"""
assert padding_idx is not None, (
"padding_idx must not be None. This is likely a dispatcher error"
)
# If padding_idx is None, use regular embedding_bag without padding
if padding_idx is None:
return aten_embedding_bag(

Check warning on line 3101 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L3101

Added line #L3101 was not covered by tests
weight, indices, offsets, scale_grad_by_freq, mode, sparse,

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

per_sample_weights, include_last_offset
)

if per_sample_weights is None:
per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
Expand Down
Loading