[rewriter] Decouple llama rule sets and make API explicit #2388


Merged · 8 commits · Jun 18, 2025
onnxscript/rewriter/__init__.py — 4 changes: 2 additions & 2 deletions
@@ -15,11 +15,11 @@
import onnxscript.ir.passes.common as common_passes
from onnxscript import ir
from onnxscript.rewriter import (
+    basic_rules,
    broadcast_to_matmul,
    cast_constant_of_shape,
    collapse_slices,
    gemm_to_matmul_add,
-    llama_rule_sets,
    no_op,
    pattern,
)
@@ -31,7 +31,7 @@
    gemm_to_matmul_add.rule,  # type: ignore[has-type]
    *cast_constant_of_shape.rules.rules,
    *collapse_slices.rules.rules,
-    *llama_rule_sets.llama_p0_rule_set().rules,
+    *basic_rules.basic_optimization_rules().rules,
)
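
For context, the default rewrite pipeline consumes this tuple, so callers pick up the renamed rules without any code change on their side. A minimal usage sketch, assuming the package-level rewriter.rewrite() entry point and a placeholder path model.onnx:

import onnx

from onnxscript import rewriter

# "model.onnx" is a placeholder path for illustration.
model_proto = onnx.load("model.onnx")

# rewrite() applies the default rules assembled above, which now come
# from basic_rules.basic_optimization_rules() rather than the old
# llama_rule_sets.llama_p0_rule_set().
rewritten_proto = rewriter.rewrite(model_proto)
onnx.save(rewritten_proto, "model.rewritten.onnx")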


onnxscript/rewriter/basic_rules.py
@@ -1,5 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+"""Basic rewrite rules for general optimization patterns.
+
+This module contains fundamental optimization rules that are generally applicable
+to most ONNX models, including cast elimination, transpose simplification,
+shape operation fusion, and other common patterns.
+"""
+
from __future__ import annotations

from typing import ClassVar, Sequence
@@ -271,6 +278,7 @@ def check(self, context, x, axes1, axes2) -> orp.MatchResult:
        return check_result


+# Create rule instances
cast_cast_rule = CastCast.rule()
cast_identity_rule = CastIdentity.rule()
expand_identity_rule = ExpandIdentity.rule()
@@ -282,21 +290,28 @@ def check(self, context, x, axes1, axes2) -> orp.MatchResult:
squeeze_reshape_1d_rule = SqueezeReshape.rule()


-def llama_p0_rule_set() -> orp.RewriteRuleSet:
-    """Returns a set of rules which should be applied
-    before any other one as they usually remove unnecessary computation
-    such as the multiplication by 1 or two consecutive transpose.
+def basic_optimization_rules() -> orp.RewriteRuleSet:
+    """Returns a set of basic optimization rules.
+
+    These rules perform fundamental optimizations such as:
+    - Eliminating redundant cast operations
+    - Simplifying consecutive operations of the same type
+    - Removing identity operations
+    - Optimizing shape manipulation operations
+
+    These rules are generally safe to apply as a first optimization pass
+    before other more specialized optimizations.

    Returns:
-        RewriteRuleSet
+        RewriteRuleSet: A collection of basic optimization rules
    """
    return orp.RewriteRuleSet(
        [
            cast_cast_rule,
            cast_identity_rule,
            expand_identity_rule,
            reshape_reshape_rule,
-            slice_split_rule,  # Affect collapse slices rules?
+            slice_split_rule,
            transpose_identity_rule,
            transpose_transpose_rule,
            unsqueeze_unsqueeze_rule,
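
For illustration, the renamed entry point is used the same way the tests in this PR use it. Below is a minimal sketch, assuming a toy model with two inverse Transpose ops (the graph here is an invented stand-in, not a fixture from this PR):

import onnx
import onnx.helper as oh

import onnxscript.rewriter.basic_rules as basic_rules
from onnxscript import ir

# Toy model: two inverse transposes that the rule set can simplify away.
graph = oh.make_graph(
    [
        oh.make_node("Transpose", ["X"], ["t"], perm=[1, 0]),
        oh.make_node("Transpose", ["t"], ["Y"], perm=[1, 0]),
    ],
    "double_transpose",
    [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [2, 3])],
    [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [2, 3])],
)
model_proto = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)])

# Same call sequence as the tests: deserialize, apply rules, serialize.
model = ir.serde.deserialize_model(model_proto)
basic_rules.basic_optimization_rules().apply_to_model(model)
rewritten = ir.serde.serialize_model(model)

# The pair of transposes should be reduced; the exact residual ops
# (e.g. an Identity) depend on which rules fire.
print([n.op_type for n in rewritten.graph.node])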
onnxscript/rewriter/basic_rules_test.py
@@ -12,7 +12,7 @@

import onnxscript
import onnxscript.onnx_types as ot
-import onnxscript.rewriter.llama_rule_sets as llama_rule_sets
+import onnxscript.rewriter.basic_rules as basic_rules
from onnxscript import ir
from onnxscript.onnx_opset import opset18

@@ -29,7 +29,7 @@
return ir.serde.deserialize_model(onnx.helper.make_model(*args, **kwargs))


-class LlamaRuleSetsTest(unittest.TestCase):
+class BasicRulesTest(unittest.TestCase):
def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
feeds: dict[str, Any] = {}
for i in model.graph.input:
@@ -97,8 +97,8 @@
),
]
)
-    def test_llama_p0_rule_set_identity(self, _: str, model: ir.Model):
-        rule_set = llama_rule_sets.llama_p0_rule_set()
+    def test_basic_optimization_rules_identity(self, _: str, model: ir.Model):
+        rule_set = basic_rules.basic_optimization_rules()
model_proto = ir.serde.serialize_model(model)
rule_set.apply_to_model(model)
rewritten_model = ir.serde.serialize_model(model)
@@ -125,8 +125,8 @@
),
]
)
-    def test_llama_p0_rule_set_transpose_transpose(self, _: str, model: ir.Model):
-        rule_set = llama_rule_sets.llama_p0_rule_set()
+    def test_basic_optimization_rules_transpose_transpose(self, _: str, model: ir.Model):
+        rule_set = basic_rules.basic_optimization_rules()
model_proto = ir.serde.serialize_model(model)
rule_set.apply_to_model(model)
rewritten_model = ir.serde.serialize_model(model)
@@ -152,17 +152,16 @@
("float16_float_float16", ot.FLOAT16, ot.FLOAT, ot.FLOAT16),
]
)
-    def test_llama_p0_rule_set_cast_cast(self, _: str, type1, type2, type3):
-        rule_set = llama_rule_sets.cast_cast_rule
+    def test_cast_cast_rule(self, _: str, type1, type2, type3):
+        rule = basic_rules.cast_cast_rule
        model_proto = self._double_cast_model(type1, type2, type3)
        model = ir.serde.deserialize_model(model_proto)
-        rule_set.apply_to_model(model)
-        rewritten_model = ir.serde.serialize_model(model)
+        rule.apply_to_model(model)
+        _rewritten_model = ir.serde.serialize_model(model)

Check notice — Code scanning / CodeQL: Unused local variable (Note)
Variable _rewritten_model is not used.

Copilot Autofix (AI, 3 days ago): the assignment to _rewritten_model can be removed entirely, since the variable is never used and the right-hand side (ir.serde.serialize_model(model)) has no side effects. Suggested changeset for onnxscript/rewriter/basic_rules_test.py:

@@ -159,3 +159,3 @@
         rule.apply_to_model(model)
-        _rewritten_model = ir.serde.serialize_model(model)
+        # Removed unused variable _rewritten_model
        self.assertEqual(["Cast"], [n.op_type for n in model.graph])
        # TODO: (random) fp16 inputs
        # self._check_model(model_proto, rewritten_model, atol=1e-2)
-        del rewritten_model  # to avoid unused variable warning

@parameterized.parameterized.expand(
[
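
The test above now exercises a single exported rule object rather than the whole set, which is the "explicit API" this PR introduces. A self-contained sketch of the same pattern (the double_cast_model helper below is a hypothetical stand-in for the test's _double_cast_model fixture; whether the rule fires is still gated by its type-compatibility check):

import onnx
import onnx.helper as oh

import onnxscript.rewriter.basic_rules as basic_rules
from onnxscript import ir


def double_cast_model() -> onnx.ModelProto:
    # Hypothetical stand-in for _double_cast_model: Cast(Cast(X)) with
    # FLOAT -> DOUBLE -> FLOAT, a lossless round trip.
    graph = oh.make_graph(
        [
            oh.make_node("Cast", ["X"], ["c"], to=onnx.TensorProto.DOUBLE),
            oh.make_node("Cast", ["c"], ["Y"], to=onnx.TensorProto.FLOAT),
        ],
        "double_cast",
        [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [4])],
        [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [4])],
    )
    return oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)])


model = ir.serde.deserialize_model(double_cast_model())
# Apply one rule directly instead of a whole RewriteRuleSet.
basic_rules.cast_cast_rule.apply_to_model(model)
print([n.op_type for n in model.graph])  # expected to collapse to a single Cast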
@@ -172,8 +171,8 @@
),
]
)
-    def test_llama_p0_rule_set_cast_identity(self, _: str, model: ir.Model):
-        rule_set = llama_rule_sets.llama_p0_rule_set()
+    def test_cast_identity_rule(self, _: str, model: ir.Model):
+        rule_set = basic_rules.basic_optimization_rules()
model_proto = ir.serde.serialize_model(model)
rule_set.apply_to_model(model)
rewritten_model = ir.serde.serialize_model(model)
@@ -226,10 +225,10 @@
),
]
)
-    def test_llama_p0_rule_set_expand_identity(
+    def test_expand_identity_rule(
        self, _: str, model: ir.Model, expected_nodes: tuple[str, ...]
    ):
-        rule_set = llama_rule_sets.llama_p0_rule_set()
+        rule_set = basic_rules.basic_optimization_rules()
model_proto = ir.serde.serialize_model(model)
rule_set.apply_to_model(model)
rewritten_model = ir.serde.serialize_model(model)
@@ -310,8 +309,8 @@
),
]
)
-    def test_llama_p0_rule_set_unsqueeze_unsqueeze(self, _: str, model: ir.Model):
-        rule_set = llama_rule_sets.llama_p0_rule_set()
+    def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model):
+        rule_set = basic_rules.basic_optimization_rules()
model_proto = ir.serde.serialize_model(model)
rule_set.apply_to_model(model)
rewritten_model = ir.serde.serialize_model(model)
@@ -369,8 +368,8 @@
),
]
)
-    def test_llama_p0_rule_set_reshape_reshape(self, _: str, model: ir.Model):
-        rule_set = llama_rule_sets.llama_p0_rule_set()
+    def test_reshape_reshape_rule(self, _: str, model: ir.Model):
+        rule_set = basic_rules.basic_optimization_rules()
model_proto = ir.serde.serialize_model(model)
rule_set.apply_to_model(model)
rewritten_model = ir.serde.serialize_model(model)
@@ -379,7 +378,7 @@
self._check_model(model_proto, rewritten_model)

    @classmethod
-    def _slides_split_models(cls):
+    def _slices_split_models(cls):
models = [
_make_model(
onnx.helper.make_graph(
@@ -418,18 +417,18 @@
return models

    @unittest.skipIf(True, reason="see https://github.com/microsoft/onnxscript/issues/1642")
-    def test_llama_p0_rule_set_slice_split(self):
-        for model_proto in self._slides_split_models():
+    def test_slices_split_rule(self):
+        for model_proto in self._slices_split_models():
            ir_model = ir.serde.deserialize_model(model_proto)
-            rule_set = llama_rule_sets.llama_p0_rule_set()
+            rule_set = basic_rules.basic_optimization_rules()
rule_set.apply_to_model(ir_model)
rewritten_model = ir.serde.serialize_model(ir_model)

self.assertEqual(["Split"], [n.op_type for n in rewritten_model.graph.node])
self._check_model(model_proto, rewritten_model)

-    def test_squeeze_reshape_1d_test(self):
-        rule = llama_rule_sets.squeeze_reshape_1d_rule
+    def test_squeeze_reshape_1d_rule(self):
+        rule = basic_rules.squeeze_reshape_1d_rule

def check(model_script, expected_count) -> None:
model_proto = model_script.to_model_proto()