Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

import onnx_ir as ir

from onnxscript.rewriter.rules.fusion import _gqa, _rms_normalization, _rotary_embedding
from onnxscript.rewriter.rules.fusion import (
_attention_present_kv,
_gqa,
_rms_normalization,
_rotary_embedding,
)


def _get_onnx_opset_version(model: ir.Model) -> int | None:
Expand All @@ -25,6 +30,9 @@ def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug)
counts["RotaryEmbedding"] = _rotary_embedding.fuse_rotary_embedding(model, debug=debug)
counts["GQA"] = _gqa.fuse_gqa(model, debug=debug)
counts["AttentionPresentKeyValue"] = (
_attention_present_kv.fuse_attention_present_key_value(model, debug=debug)
)
return counts


Expand Down
64 changes: 64 additions & 0 deletions onnxscript/rewriter/rules/fusion/_attention_present_kv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnx_ir as ir

import onnxscript.rewriter._fusion_utils as _fusion_utils
from onnxscript.rewriter import pattern


class AttentionPresentKeyValue(pattern.RewriteRuleClassBase):
    """Move present_key and present_value to be generated by Attention.

    When torch.onnx exports a model from transformers with SDPA, it generates a Concat
    node to concatenate past_key/value with the new key/value to produce the graph output
    for kv cache. This pattern can be fused into the Attention node, which has present_key
    and present_value outputs. It is necessary for ONNX Runtime because it requires the outputs
    to be produced by the Attention node when past_key and past_value inputs are provided.
    """

    def pattern(
        self,
        op,
        query,
        key,
        value,
        mask,
        past_key,
        past_value,
    ):
        """Match an Attention node whose kv-cache outputs come from separate Concat nodes.

        Matches the exporter-produced subgraph where present_key/present_value are
        built by concatenating the past cache with the new key/value along the
        sequence axis (-2), while Attention itself only produces the attention output.
        """
        # Exporter-generated cache update: past and current k/v are concatenated
        # along the sequence dimension (axis=-2).
        present_key = op.Concat(past_key, key, axis=-2)
        present_value = op.Concat(past_value, value, axis=-2)

        # Single-output Attention; `_outputs` names the matched value so the
        # rewrite can retrieve the original node via `attention_out.producer()`.
        attention_out = op.Attention(
            query, key, value, mask, past_key, past_value, _outputs=["attention_out"]
        )

        # All three values must be rooted in the same match so the replacement
        # Attention node can take over producing present_key/present_value.
        return attention_out, present_key, present_value

    def rewrite(
        self,
        op,
        query: ir.Value,
        key: ir.Value,
        value: ir.Value,
        mask: ir.Value,
        past_key: ir.Value,
        past_value: ir.Value,
        attention_out: ir.Value,
        **_,
    ):
        """Replace the matched subgraph with a 3-output Attention node.

        The new node reuses the original Attention node's attributes unchanged and
        emits (attention_out, present_key, present_value) itself, eliminating the
        two Concat nodes matched by :meth:`pattern`.
        """
        # Recover the matched Attention node to copy its attributes verbatim.
        original_attention_node = attention_out.producer()
        assert original_attention_node is not None
        original_attrs = original_attention_node.attributes
        # `_outputs=3` requests the three-output form of Attention
        # (output, present_key, present_value).
        return op.Attention(
            query, key, value, mask, past_key, past_value, **original_attrs, _outputs=3
        )


# Singleton rule instance; exposed at module level so callers can compose it
# into custom rule sets.
attention_present_key_value_rule = AttentionPresentKeyValue.rule()

# Public entry point used by the opset-23 fusion pass: applies the rule above
# to a model and returns the number of successful fusions.
fuse_attention_present_key_value = _fusion_utils.apply_fusion_rules(
    attention_present_key_value_rule
)
Loading