Skip to content

[torchlib] Implement torch.ops.prims.broadcast_in_dim.default #2382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

Copilot
Copy link

@Copilot Copilot AI commented Jun 14, 2025

This PR implements the missing torch.ops.prims.broadcast_in_dim.default operation that appears in BERT_pytorch and other PyTorch models.

Overview

The broadcast_in_dim operation is a primitive that broadcasts a tensor to a target shape by specifying which dimensions of the output correspond to the input tensor dimensions. This is different from standard broadcasting operations.

Implementation Details

Function signature:

def prims_broadcast_in_dim(
    a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int]
) -> TensorType:

Parameters:

  • a: Input tensor to broadcast
  • shape: Target output shape
  • broadcast_dimensions: Specifies which dimensions of the output shape correspond to the input tensor dimensions

Example:

# Input tensor: [3, 4]
# Target shape: [2, 3, 5, 4] 
# broadcast_dimensions: [1, 3]
# Result: Input dimension 0 (size 3) maps to output dimension 1
#         Input dimension 1 (size 4) maps to output dimension 3
#         Output dimensions 0 and 2 are broadcasted (filled from size 1)

Algorithm

  1. Create intermediate shape with all 1s using op.ConstantOfShape
  2. Use op.ScatterElements to place input dimensions at specified broadcast positions
  3. Reshape input tensor to intermediate shape using op.Reshape
  4. Expand to final target shape using op.Expand

The implementation uses only ONNX graph operations suitable for static compilation, avoiding Python loops that would break ONNXScript's graph construction.

Testing

  • ✅ Function properly registered with @torch_op("prims::broadcast_in_dim", trace_only=True)
  • ✅ Logic validated with numpy equivalent operations
  • ✅ Imports correctly as TracedOnnxFunction
  • ✅ No syntax or compilation errors
  • ✅ Accessible from onnxscript.function_libs.torch_lib.ops.prims

Fixes #2218.


💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click here to start the survey.

@Copilot Copilot AI changed the title [WIP] [torchlib] Implement torch.ops.prims.broadcast_in_dim.default [torchlib] Implement torch.ops.prims.broadcast_in_dim.default Jun 14, 2025
@Copilot Copilot AI requested a review from justinchuby June 14, 2025 03:55
Copilot finished work on behalf of justinchuby June 14, 2025 03:55
# Convert broadcast_dimensions and input shape to tensors we can work with
broadcast_dims_tensor = op.Constant(value_ints=list(broadcast_dimensions))
input_rank = op.Size(input_shape)
indices = op.Range(op.Constant(value_int=0), input_rank, op.Constant(value_int=1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this known? Can it become list(range(input_rank))?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're absolutely right! While we can't use list(range(input_rank)) directly since input_rank is a tensor, the underlying issue was that we were unnecessarily creating indices just to gather all elements of input_shape in order. I've simplified this to use input_shape directly instead of op.Gather(input_shape, indices). Fixed in 6567cfb.

Comment on lines 199 to 211
# Scatter the input dimensions into the intermediate shape at the specified positions
intermediate_shape = op.ScatterElements(
ones,
op.Unsqueeze(broadcast_dims_tensor, axes=[0]),
op.Unsqueeze(op.Gather(input_shape, indices), axes=[0]),
axis=0
)

# Reshape the input tensor to the intermediate shape
reshaped = op.Reshape(a, intermediate_shape)

# Expand to the target shape
return op.Expand(reshaped, shape)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to simplify and avoid op.ScatterElements, and Gather, which are costly?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've simplified the implementation to avoid ScatterElements by using Where operations with position masks instead. The new approach creates individual masks for each broadcast dimension and uses Where to selectively replace 1s with the actual input dimension values. This should be more efficient while maintaining the same functionality. Fixed in 5c35fe1.

Copilot finished work on behalf of justinchuby June 14, 2025 14:30
Copy link

codecov bot commented Jun 14, 2025

❌ 3 Tests Failed:

Tests completed Failed Passed Skipped
15725 3 15722 1697
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0125_test_ai_onnx_ml_tree_ensemble_set_membership
Stack Traces | 0.007s run time
onnxscript/converter.py:460: in _eval_constant_expr
    return eval(cpl, self.globals, locals)  # pylint: disable=eval-used
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   NameError: name 'nan' is not defined

The above exception was the direct cause of the following exception:
..../test_ort_nightly/lib/python3.11.../site-packages/parameterized/parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/backend/onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/backend/onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
<frozen importlib._bootstrap>:1204: in _gcd_import
    ???
<frozen importlib._bootstrap>:1176: in _find_and_load
    ???
<frozen importlib._bootstrap>:1147: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:690: in _load_unlocked
    ???
..../test_ort_nightly/lib/python3.11.../_pytest/assertion/rewrite.py:186: in exec_module
    exec(co, module.__dict__)
tests/onnx_backend_test_code/test_ai_onnx_ml_tree_ensemble_set_membership.py:9: in <module>
    @script()
     ^^^^^^^^
onnxscript/main.py:94: in transform
    result = script_check(f_ast, opset, env, src, default_opset=default_opset)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/main.py:38: in script_check
    return convert.translate_function_def(f)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:1452: in translate_function_def
    fn_ir = self._translate_function_def_common(stmt)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:1439: in _translate_function_def_common
    self._translate_stmt(s, index_of_stmt=i)
onnxscript/converter.py:961: in _translate_stmt
    return self._translate_assign_stmt(node)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:1048: in _translate_assign_stmt
    assign(lhs, rhs)
onnxscript/converter.py:992: in assign
    t = self._translate_expr(rhs, lhs).name
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:546: in _translate_expr
    r = self._translate_call_expr(node)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:825: in _translate_call_expr
    attrs = [
onnxscript/converter.py:826: in <listcomp>
    self._translate_attr(x, y, callee.op_schema.attributes[x])
onnxscript/converter.py:510: in _translate_attr
    val = self._eval_constant_expr(expr)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:462: in _eval_constant_expr
    raise NameError(
E   NameError: ERROR: Missing names, globals contains ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__file__', '__cached__', '__builtins__', '@py_builtins', '@pytest_ar', 'numpy', 'TensorProto', 'make_tensor', 'script', 'external_tensor', 'Opset', 'FLOAT', 'ai_onnx_ml5'], locals [].
E   at: Function 'bck_test_ai_onnx_ml_tree_ensemble_set_membership', line 3
E       Y = ai_onnx_ml5.TreeEnsemble(X, aggregate_function=1, leaf_targetids=[0, 1, 2, 3], leaf_weights=make_tensor("value", 1, dims=[4], vals=[1.0, 10.0, 1000.0, 100.0]), membership_values=make_tensor("value", 1, dims=[8], vals=[1.2000000476837158, 3.700000047683716, 8.0, 9.0, nan, 12.0, 7.0, nan]), n_targets=4, nodes_falseleafs=[1, 0, 1], nodes_falsenodeids=[2, 2, 3], nodes_featureids=[0, 0, 0], nodes_modes=make_tensor("value", 2, dims=[3], vals=[0, 6, 6]), nodes_splits=make_tensor("value", 1, dims=[3], vals=[11.0, 232344.0, nan]), nodes_trueleafs=[0, 1, 1], nodes_truenodeids=[1, 0, 1], post_transform=0, tree_roots=[0])
E                                                                                                                                                                                             ^
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0904_test_ai_onnx_ml_tree_ensemble_set_membership
Stack Traces | 0.01s run time
onnxscript/converter.py:460: in _eval_constant_expr
    return eval(cpl, self.globals, locals)  # pylint: disable=eval-used
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   NameError: name 'nan' is not defined

The above exception was the direct cause of the following exception:
..../test_ort_nightly/lib/python3.11.../site-packages/parameterized/parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/backend/onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/backend/onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../hostedtoolcache/Python/3.11.12.../x64/lib/python3.11/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
<frozen importlib._bootstrap>:1204: in _gcd_import
    ???
<frozen importlib._bootstrap>:1176: in _find_and_load
    ???
<frozen importlib._bootstrap>:1147: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:690: in _load_unlocked
    ???
..../test_ort_nightly/lib/python3.11.../_pytest/assertion/rewrite.py:186: in exec_module
    exec(co, module.__dict__)
tests/onnx_backend_test_code/test_ai_onnx_ml_tree_ensemble_set_membership.py:9: in <module>
    @script()
     ^^^^^^^^
onnxscript/main.py:94: in transform
    result = script_check(f_ast, opset, env, src, default_opset=default_opset)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/main.py:38: in script_check
    return convert.translate_function_def(f)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:1452: in translate_function_def
    fn_ir = self._translate_function_def_common(stmt)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:1439: in _translate_function_def_common
    self._translate_stmt(s, index_of_stmt=i)
onnxscript/converter.py:961: in _translate_stmt
    return self._translate_assign_stmt(node)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:1048: in _translate_assign_stmt
    assign(lhs, rhs)
onnxscript/converter.py:992: in assign
    t = self._translate_expr(rhs, lhs).name
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:546: in _translate_expr
    r = self._translate_call_expr(node)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:825: in _translate_call_expr
    attrs = [
onnxscript/converter.py:826: in <listcomp>
    self._translate_attr(x, y, callee.op_schema.attributes[x])
onnxscript/converter.py:510: in _translate_attr
    val = self._eval_constant_expr(expr)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/converter.py:462: in _eval_constant_expr
    raise NameError(
E   NameError: ERROR: Missing names, globals contains ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__file__', '__cached__', '__builtins__', '@py_builtins', '@pytest_ar', 'numpy', 'TensorProto', 'make_tensor', 'script', 'external_tensor', 'Opset', 'FLOAT', 'ai_onnx_ml5'], locals [].
E   at: Function 'bck_test_ai_onnx_ml_tree_ensemble_set_membership', line 3
E       Y = ai_onnx_ml5.TreeEnsemble(X, aggregate_function=1, leaf_targetids=[0, 1, 2, 3], leaf_weights=make_tensor("value", 1, dims=[4], vals=[1.0, 10.0, 1000.0, 100.0]), membership_values=make_tensor("value", 1, dims=[8], vals=[1.2000000476837158, 3.700000047683716, 8.0, 9.0, nan, 12.0, 7.0, nan]), n_targets=4, nodes_falseleafs=[1, 0, 1], nodes_falsenodeids=[2, 2, 3], nodes_featureids=[0, 0, 0], nodes_modes=make_tensor("value", 2, dims=[3], vals=[0, 6, 6]), nodes_splits=make_tensor("value", 1, dims=[3], vals=[11.0, 232344.0, nan]), nodes_trueleafs=[0, 1, 1], nodes_truenodeids=[1, 0, 1], post_transform=0, tree_roots=[0])
E                                                                                                                                                                                             ^
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0026_test_ai_onnx_ml_tree_ensemble_set_membership
Stack Traces | 0.043s run time
onnxscript\converter.py:460: in _eval_constant_expr
    return eval(cpl, self.globals, locals)  # pylint: disable=eval-used
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   NameError: name 'nan' is not defined

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
<frozen importlib._bootstrap>:1204: in _gcd_import
    ???
<frozen importlib._bootstrap>:1176: in _find_and_load
    ???
<frozen importlib._bootstrap>:1147: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:690: in _load_unlocked
    ???
.nox\test_ort_nightly\Lib\site-packages\_pytest\assertion\rewrite.py:186: in exec_module
    exec(co, module.__dict__)
tests\onnx_backend_test_code\test_ai_onnx_ml_tree_ensemble_set_membership.py:9: in <module>
    @script()
     ^^^^^^^^
onnxscript\main.py:94: in transform
    result = script_check(f_ast, opset, env, src, default_opset=default_opset)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\main.py:38: in script_check
    return convert.translate_function_def(f)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\converter.py:1452: in translate_function_def
    fn_ir = self._translate_function_def_common(stmt)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\converter.py:1439: in _translate_function_def_common
    self._translate_stmt(s, index_of_stmt=i)
onnxscript\converter.py:961: in _translate_stmt
    return self._translate_assign_stmt(node)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\converter.py:1048: in _translate_assign_stmt
    assign(lhs, rhs)
onnxscript\converter.py:992: in assign
    t = self._translate_expr(rhs, lhs).name
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\converter.py:546: in _translate_expr
    r = self._translate_call_expr(node)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\converter.py:825: in _translate_call_expr
    attrs = [
onnxscript\converter.py:826: in <listcomp>
    self._translate_attr(x, y, callee.op_schema.attributes[x])
onnxscript\converter.py:510: in _translate_attr
    val = self._eval_constant_expr(expr)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\converter.py:462: in _eval_constant_expr
    raise NameError(
E   NameError: ERROR: Missing names, globals contains ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__file__', '__cached__', '__builtins__', '@py_builtins', '@pytest_ar', 'numpy', 'TensorProto', 'make_tensor', 'script', 'external_tensor', 'Opset', 'FLOAT', 'ai_onnx_ml5'], locals [].
E   at: Function 'bck_test_ai_onnx_ml_tree_ensemble_set_membership', line 3
E       Y = ai_onnx_ml5.TreeEnsemble(X, aggregate_function=1, leaf_targetids=[0, 1, 2, 3], leaf_weights=make_tensor("value", 1, dims=[4], vals=[1.0, 10.0, 1000.0, 100.0]), membership_values=make_tensor("value", 1, dims=[8], vals=[1.2000000476837158, 3.700000047683716, 8.0, 9.0, nan, 12.0, 7.0, nan]), n_targets=4, nodes_falseleafs=[1, 0, 1], nodes_falsenodeids=[2, 2, 3], nodes_featureids=[0, 0, 0], nodes_modes=make_tensor("value", 2, dims=[3], vals=[0, 6, 6]), nodes_splits=make_tensor("value", 1, dims=[3], vals=[11.0, 232344.0, nan]), nodes_trueleafs=[0, 1, 1], nodes_truenodeids=[1, 0, 1], post_transform=0, tree_roots=[0])
E                                                                                                                                                                                             ^

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

Copilot finished work on behalf of justinchuby June 14, 2025 14:43
# while still leveraging compile-time knowledge of broadcast_dimensions

input_shape = op.Shape(a)
target_rank = op.Size(shape)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
target_rank = op.Size(shape)
target_rank = len(shape)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to use target_rank = len(shape) for compile-time efficiency instead of the runtime op.Size(shape) operation. Fixed in 65ed52f.

Comment on lines 201 to 202
ones = op.ConstantOfShape(op.Unsqueeze(target_rank, axes=[0]), value=op.Constant(value_int=1))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ones = op.ConstantOfShape(op.Unsqueeze(target_rank, axes=[0]), value=op.Constant(value_int=1))
ones = [1] * target_rank

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced the runtime ConstantOfShape operation with a compile-time Python list [1] * target_rank for better efficiency. Fixed in c81a4d4.

Copy link
Contributor

@github-advanced-security github-advanced-security bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

Copilot finished work on behalf of justinchuby June 14, 2025 14:55
Copilot finished work on behalf of justinchuby June 14, 2025 15:00
@justinchuby
Copy link
Collaborator

What does broadcast_in_dim do?

Copy link
Author

Copilot AI commented Jun 14, 2025

What does broadcast_in_dim do?

broadcast_in_dim is a PyTorch primitive that broadcasts a tensor to a target shape by explicitly specifying which dimensions of the output correspond to the input tensor dimensions. Unlike standard broadcasting, it uses the broadcast_dimensions parameter to map input dimensions to specific output positions.

For example:

  • Input tensor shape: [3, 4]
  • Target shape: [2, 3, 5, 4]
  • broadcast_dimensions: [1, 3]
  • Result: Input dim 0 (size 3) → output dim 1, input dim 1 (size 4) → output dim 3
  • Output dims 0 and 2 get size 1 and are expanded

This gives more precise control over broadcasting compared to automatic broadcasting rules.

Copilot finished work on behalf of justinchuby June 14, 2025 15:07
@justinchuby justinchuby marked this pull request as ready for review June 14, 2025 22:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Development

Successfully merging this pull request may close these issues.

[torchlib] Implement torch.ops.prims.broadcast_in_dim.default
2 participants