Description
Describe the bug
When exporting the NousResearch/Meta-Llama-3.1-8B-Instruct model to ONNX using optimum-cli with --opset 14, the exported file carries ONNX IR version 11 (the version written by the installed onnx 1.18.0), and the resulting model.onnx (or sub-models created from it) fails to load with onnxruntime due to a broadcasting error. The error occurs inside the attention mechanism, at the Add node that applies the attention mask to the attention scores.
Urgency
None specified. This blocks further work on splitting the model for distributed inference.
System information
OS Platform and Distribution: Linux (Google Colab environment)
TensorFlow Version: Not applicable (Hugging Face Transformers / Optimum export, not TensorFlow)
Python version: 3.11 (as observed in Colab)
ONNX version: 1.18.0
ONNXRuntime version: 1.22.0
To Reproduce
Export the model using optimum-cli with opset 14:
optimum-cli export onnx \
  --model NousResearch/Meta-Llama-3.1-8B-Instruct \
  --task text-generation-with-past \
  --fp16 \
  --device cuda \
  --opset 14 \
  ./output_model_dir/  # output directory, e.g., /content/llama_onnx/monolithic_model
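Optionally, before loading, the opset and IR version actually written to the file can be confirmed with a small script (the IR version is determined by the installed onnx library rather than by --opset; with onnx 1.18.0 it is expected to be 11). A minimal sketch, assuming the output path used above:

import onnx

# Read only the proto metadata; the external fp16 weight files are not needed for this.
model = onnx.load("./output_model_dir/model.onnx", load_external_data=False)

print("IR version:", model.ir_version)  # expected 11 with onnx 1.18.0
for opset in model.opset_import:
    print("opset:", opset.domain or "ai.onnx", opset.version)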
Attempt to load or validate the exported model.onnx file using onnxruntime.InferenceSession:
import onnxruntime as ort

try:
    session = ort.InferenceSession("./output_model_dir/model.onnx")
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
Expected Behavior:
The ONNX model should load and validate successfully, without any broadcasting or type error.
Actual Behavior / Error Message:
The loading process fails with an ONNXRuntimeError indicating a broadcasting incompatibility:
[ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Add node. Name:'/model/layers.0/self_attn/Add_2' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 128 by 256
(Note: in initial debugging this error was preceded by an "Unsupported model IR version: 11, max supported IR version: 10" error. After forcing ir_version=7 during sub-model creation, the IR version error went away, but this underlying broadcasting error remained and became consistently visible.)
Screenshots
Not applicable (error is in console output).
Additional context
Model Details: The model is a Llama 3.1 8B Instruct variant, typically using Grouped Query Attention (GQA) and Rotary Positional Embeddings (RoPE). It's exported with --fp16 for half-precision.
Problematic Node Analysis:
The error points to /model/layers.0/self_attn/Add_2.
Netron analysis shows this node's inputs are typically the raw attention scores (MatMul_output_0) and a processed attention mask (Slice_4_output_0).
The error message describes a dimension mismatch: an axis of size 128 is being broadcast against one of size 256, while broadcasting only permits equal sizes or a size of 1. This points to an issue in how the attention mask's shape is represented, or how broadcasting is resolved, for this opset and model.
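The same inspection can be scripted without Netron. A minimal sketch, using only the standard onnx Python API and the node name from the error message, that runs shape inference and prints the inferred shapes of the failing Add node's inputs:

import onnx
from onnx import shape_inference

# Load without external weight data; the proto alone carries the shape information.
model = onnx.load("./output_model_dir/model.onnx", load_external_data=False)
# For protos larger than 2 GB, onnx.shape_inference.infer_shapes_path can be used instead.
inferred = shape_inference.infer_shapes(model)

# Map every value name in the graph to its inferred shape.
shapes = {}
for vi in list(inferred.graph.value_info) + list(inferred.graph.input) + list(inferred.graph.output):
    dims = [d.dim_value if d.HasField("dim_value") else d.dim_param
            for d in vi.type.tensor_type.shape.dim]
    shapes[vi.name] = dims

# Print the inputs feeding the Add node named in the error message.
for node in inferred.graph.node:
    if node.name == "/model/layers.0/self_attn/Add_2":
        for name in node.input:
            print(name, "->", shapes.get(name, "shape not inferred"))

If the mask input reports a fixed 256 where the attention-score input reports 128 (or a symbolic sequence length), that matches the mismatch in the runtime error.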
Troubleshooting Steps Attempted (and their outcomes):
Upgrading onnxruntime: pip reports 1.22.0 as the latest release, and it still caps the supported IR version at 10.
Downgrading the onnx library (e.g., to 1.10.0, 1.11.0, or 1.12.0): every attempt failed with subprocess-exited-with-error during the build in the Colab environment, making a downgrade of the onnx library itself impossible there.
Forcing ir_version=7 in onnx.helper.make_model during sub-model creation: this resolved the Unsupported IR Version error (the sub-models were saved with the lower IR version), but the broadcasting error persisted, which suggests the problem is structural in the exported graph itself (a minimal sketch of this workaround is shown after this list).
Attempted solution (pending confirmation): re-exporting the original model with optimum-cli using a lower --opset (e.g., --opset 11 or --opset 10) is the next proposed step, as this might generate a more compatible ONNX graph without this specific broadcasting issue.
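For completeness, this is roughly how the IR-version workaround referenced above was applied. It only rewrites the header field of the saved proto and does not touch the graph, which is consistent with the broadcasting error surviving it. The sub-model path below is a placeholder for whatever files the splitting step produces:

import onnx

# Hypothetical sub-model path; substitute the actual file produced during splitting.
sub = onnx.load("./output_model_dir/submodel_0.onnx", load_external_data=False)

# Force a lower IR version so onnxruntime 1.22.0 (max supported IR version 10) accepts the file.
# Only the header field changes; the graph and its operators are untouched, so any
# structural problem such as the Add broadcasting mismatch remains.
sub.ir_version = 7
onnx.save(sub, "./output_model_dir/submodel_0_ir7.onnx")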
Overall Problem: The issue seems to stem from optimum-cli's export process for this specific model/task/opset combination, leading to an ONNX graph structure that onnxruntime's broadcasting rules cannot handle.