
fix: check for rank of bias in bias-gelu fusion🐛 #2393


Merged: 2 commits into microsoft:main on Jun 16, 2025

Conversation

KarelZe (Contributor)

@KarelZe commented Jun 15, 2025

Follow-up to #2364.

I noticed that the current implementation of BiasGeluFusion from #2364 does not check the rank of the bias term, which can lead to runtime errors, as the bias input of BiasGelu(...) is expected to be 1-D (see here).
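In essence, the fusion needs a guard like the following. This is a minimal sketch, not the actual onnxscript code: the function name and the use of a NumPy array to stand in for the bias initializer are assumptions for illustration.

```python
import numpy as np

def can_fuse_bias_gelu(bias: np.ndarray) -> bool:
    """Return True only when the bias is 1-D, as ORT's BiasGelu kernel expects.

    Hypothetical helper: the real check lives in the BiasGeluFusion rule.
    """
    return bias.ndim == 1

bias_1d = np.zeros(4)            # shape (4,)      -> safe to fuse
bias_3d = np.zeros((1, 1, 4))    # shape (1, 1, 4) -> must not fuse

print(can_fuse_bias_gelu(bias_1d))  # True
print(can_fuse_bias_gelu(bias_3d))  # False
```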

Minimal, complete example

With:

uv pip install git+https://github.com/microsoft/onnxscript.git --force-reinstall
import os

import numpy as np
import onnx_ir as ir
import torch
from onnxscript.rewriter.ort_fusions._core import fuse_xformers
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import onnxruntime as ort

os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_name = "hf-internal-testing/tiny-random-bart"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.eval()


class EncoderWrapper(torch.nn.Module):
    """A wrapper around the BART encoder for onnx export."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        outs = self.encoder(input_ids, attention_mask)
        return outs["last_hidden_state"]


model = EncoderWrapper(encoder=model.model.encoder)
print(model)

text = "God bless the internet."
inputs = tokenizer(text, return_tensors="pt")

input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

input_names = ["input_ids"]
output_names = ["encoder_output"]

onnx_path = "bart_encoder.onnx"

torch.onnx.export(
    model,
    (input_ids,),
    onnx_path,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "encoder_output": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=20,
)

onnx_model = ir.load(onnx_path)
onnx_model, stats = fuse_xformers(onnx_model)
print(stats)

optimized_path = "optimized_model.onnx"
ir.save(onnx_model, optimized_path)

sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
encoder_outs_original = sess.run(["encoder_output"], {"input_ids": input_ids.numpy()})

sess_optimized = ort.InferenceSession(optimized_path, providers=["CPUExecutionProvider"])
encoder_outs_optimized = sess_optimized.run(["encoder_output"], {"input_ids": input_ids.numpy()})

abs_diff = np.amax(np.abs(encoder_outs_original[0] - encoder_outs_optimized[0]))
print("abs_difference", abs_diff)
Applied 1 of general pattern rewrite rules.
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'partial_rotary_embedding': 0, 'cos_sin_cache': 0, 'sdpa': 0, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}
2025-06-15 20:52:33.994324 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/activation_fn/Gelu_output_0' source:{4} target:{-1,-1,4}. Falling back to lenient merge.
2025-06-15 20:52:33.994582 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/activation_fn/Gelu_output_0' source:{4} target:{-1,-1,4}. Falling back to lenient merge.
2025-06-15 20:52:34.007963 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/fc2/MatMul_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008178 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/fc2/MatMul_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008753 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/fc2/Add_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008944 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/fc2/Add_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.018753 [E:onnxruntime:, sequential_executor.cc:572 ExecuteKernel] Non-zero status code returned while running BiasGelu node. Name:'node_BiasGelu_26' Status Message: Input 1 is expected to have 1 dimensions, got 3
...
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running BiasGelu node. Name:'node_BiasGelu_26' Status Message: Input 1 is expected to have 1 dimensions, got 3

with:

uv pip install git+https://github.com/karelze/onnxscript.git@fix-bias-gelu-shape --force-reinstall
Applied 1 of general pattern rewrite rules.
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'partial_rotary_embedding': 0, 'cos_sin_cache': 0, 'sdpa': 0, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}
abs_difference 0.0
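For context on why the unfused graph worked in the first place: a plain Add node broadcasts a bias of any compatible rank, whereas the fused BiasGelu kernel insists on a 1-D bias. A small NumPy sketch (the shapes are illustrative, chosen to mirror the tiny model's FFN width of 4):

```python
import numpy as np

# Plain Add + Gelu tolerates any broadcastable bias shape; the fused
# BiasGelu kernel does not. Squeezing the bias to 1-D keeps the math identical.
x = np.ones((2, 3, 4))               # (batch, seq, hidden)
bias_3d = np.full((1, 1, 4), 0.5)    # rank 3: fine for Add, rejected by BiasGelu
bias_1d = bias_3d.reshape(4)         # rank 1: the layout BiasGelu expects

# Broadcasting makes the two additions element-wise identical.
print(np.array_equal(x + bias_3d, x + bias_1d))  # True
```

So when the bias rank is not 1, the correct behavior is simply to leave the Add + Gelu pair unfused, which is what the added check does.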

This PR adds:

  • an additional check for the rank of the bias
  • additional test cases

Sorry for the inconvenience.

@justinchuby @titaiwangms

@KarelZe marked this pull request as ready for review June 15, 2025 19:38
@KarelZe requested a review from justinchuby June 15, 2025 19:39

codecov bot commented Jun 15, 2025

Codecov Report

Attention: Patch coverage is 72.72727% with 6 lines in your changes missing coverage. Please review.

Project coverage is 70.37%. Comparing base (b76e1b3) to head (68b1dc4).
Report is 1 commit behind head on main.

Files with missing lines                            Patch %   Lines
onnxscript/rewriter/ort_fusions/bias_gelu_test.py   57.14%    6 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2393   +/-   ##
=======================================
  Coverage   70.37%   70.37%           
=======================================
  Files         199      199           
  Lines       25206    25216   +10     
  Branches     2685     2686    +1     
=======================================
+ Hits        17739    17747    +8     
- Misses       6537     6540    +3     
+ Partials      930      929    -1     


@justinchuby justinchuby merged commit 59340c6 into microsoft:main Jun 16, 2025
29 of 33 checks passed