-
Notifications
You must be signed in to change notification settings - Fork 66
[torchlib] Implement torch.ops.prims.broadcast_in_dim.default #2382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Co-authored-by: justinchuby <[email protected]>
Co-authored-by: justinchuby <[email protected]>
torch.ops.prims.broadcast_in_dim.default
# Convert broadcast_dimensions and input shape to tensors we can work with | ||
broadcast_dims_tensor = op.Constant(value_ints=list(broadcast_dimensions)) | ||
input_rank = op.Size(input_shape) | ||
indices = op.Range(op.Constant(value_int=0), input_rank, op.Constant(value_int=1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this known? Can it become list(range(input_rank))
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're absolutely right! While we can't use list(range(input_rank))
directly since input_rank
is a tensor, the underlying issue was that we were unnecessarily creating indices just to gather all elements of input_shape
in order. I've simplified this to use input_shape
directly instead of op.Gather(input_shape, indices)
. Fixed in 6567cfb.
# Scatter the input dimensions into the intermediate shape at the specified positions | ||
intermediate_shape = op.ScatterElements( | ||
ones, | ||
op.Unsqueeze(broadcast_dims_tensor, axes=[0]), | ||
op.Unsqueeze(op.Gather(input_shape, indices), axes=[0]), | ||
axis=0 | ||
) | ||
|
||
# Reshape the input tensor to the intermediate shape | ||
reshaped = op.Reshape(a, intermediate_shape) | ||
|
||
# Expand to the target shape | ||
return op.Expand(reshaped, shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way to simplify and avoid op.ScatterElements, and Gather, which are costly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've simplified the implementation to avoid ScatterElements by using Where operations with position masks instead. The new approach creates individual masks for each broadcast dimension and uses Where to selectively replace 1s with the actual input dimension values. This should be more efficient while maintaining the same functionality. Fixed in 5c35fe1.
Co-authored-by: justinchuby <[email protected]>
❌ 3 Tests Failed:
View the top 3 failed test(s) by shortest run time
To view more test analytics, go to the Test Analytics Dashboard |
…ions Co-authored-by: justinchuby <[email protected]>
# while still leveraging compile-time knowledge of broadcast_dimensions | ||
|
||
input_shape = op.Shape(a) | ||
target_rank = op.Size(shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
target_rank = op.Size(shape) | |
target_rank = len(shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to use target_rank = len(shape)
for compile-time efficiency instead of the runtime op.Size(shape)
operation. Fixed in 65ed52f.
ones = op.ConstantOfShape(op.Unsqueeze(target_rank, axes=[0]), value=op.Constant(value_int=1)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ones = op.ConstantOfShape(op.Unsqueeze(target_rank, axes=[0]), value=op.Constant(value_int=1)) | |
ones = [1] * target_rank | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replaced the runtime ConstantOfShape operation with a compile-time Python list [1] * target_rank
for better efficiency. Fixed in c81a4d4.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
… review Co-authored-by: justinchuby <[email protected]>
…in review Co-authored-by: justinchuby <[email protected]>
What does broadcast_in_dim do? |
For example:
This gives more precise control over broadcasting compared to automatic broadcasting rules. |
This PR implements the missing
torch.ops.prims.broadcast_in_dim.default
operation that appears in BERT_pytorch and other PyTorch models.Overview
The
broadcast_in_dim
operation is a primitive that broadcasts a tensor to a target shape by specifying which dimensions of the output correspond to the input tensor dimensions. This is different from standard broadcasting operations.Implementation Details
Function signature:
Parameters:
a
: Input tensor to broadcastshape
: Target output shapebroadcast_dimensions
: Specifies which dimensions of the output shape correspond to the input tensor dimensionsExample:
Algorithm
op.ConstantOfShape
op.ScatterElements
to place input dimensions at specified broadcast positionsop.Reshape
op.Expand
The implementation uses only ONNX graph operations suitable for static compilation, avoiding Python loops that would break ONNXScript's graph construction.
Testing
@torch_op("prims::broadcast_in_dim", trace_only=True)
TracedOnnxFunction
onnxscript.function_libs.torch_lib.ops.prims
Fixes #2218.
💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click here to start the survey.