Skip to content

Clean up rewriter code: improve efficiency, finish TODOs, and enhance documentation #2392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@
)

_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)

# Default rewrite rules applied by the rewriter
# These rules implement common optimizations and transformations
_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
*no_op.rules.rules, # TODO: merge this rule into constant folding?
*no_op.rules.rules, # Remove no-op operations (e.g., Add with 0, Mul by 1)
*broadcast_to_matmul.rules.rules,
gemm_to_matmul_add.rule, # type: ignore[has-type]
*cast_constant_of_shape.rules.rules,
Expand All @@ -36,6 +39,20 @@


class RewritePass(ir.passes.InPlacePass):
"""A pass that applies pattern-based rewrite rules to an IR model.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

This pass takes a collection of rewrite rules and applies them to the model,
transforming matching patterns according to the rule definitions. The pass
operates in-place, modifying the provided model directly.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Args:
rules: A sequence of RewriteRule objects or a RewriteRuleSet containing
the rules to apply during rewriting.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Raises:
ValueError: If the rules sequence is empty.
"""

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

def __init__(
self,
rules: Sequence[pattern.RewriteRule] | pattern.RewriteRuleSet,
Expand Down
117 changes: 94 additions & 23 deletions onnxscript/rewriter/_basics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Basic types for the pattern matching and rewriter API."""
"""Basic types for the pattern matching and rewriter API.

This module contains fundamental data structures and utilities used throughout
the rewriter system:

- MatchResult: Tracks the state of pattern matching operations
- MatchFailureInfo/MatchFailureError: Handle match failure scenarios

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

- PartialMatchResult: Internal state for managing backtracking during OR patterns
- Utility functions for value comparison and binding management

The matching system supports advanced features like:
- OR patterns with backtracking
- Robust value binding with conflict detection
- Detailed failure reporting for debugging
- Support for both named and anonymous pattern variables
"""

from __future__ import annotations

Expand All @@ -11,6 +26,33 @@

from onnxscript import ir

def _values_equal(value1: Any, value2: Any) -> bool:
"""Check if two values are equal for binding purposes.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

This function provides a more robust equality check than direct comparison,
handling special cases for IR values and nodes.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Args:
value1: First value to compare
value2: Second value to compare

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Returns:
True if the values are considered equal for binding purposes
"""
if value1 is value2:
return True

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

# For IR values and nodes, use identity comparison
if isinstance(value1, (ir.Value, ir.Node)) or isinstance(value2, (ir.Value, ir.Node)):
return value1 is value2

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

# For other types, use regular equality
try:
return value1 == value2
except Exception:

Check warning on line 52 in onnxscript/rewriter/_basics.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_basics.py#L50-L52

Added lines #L50 - L52 were not covered by tests

Check warning

Code scanning / CodeQL

Unreachable code Warning

This statement is unreachable.

Copilot Autofix

AI 5 days ago

To fix the issue, the unreachable except Exception: block should be removed. This simplifies the code and makes it clearer, while preserving its intended functionality. The equality comparison (value1 == value2) is robust enough for standard use cases, and the removal of the unreachable code does not affect the behavior of the function.

Changes to make:

  1. Remove the try block and the except Exception: block entirely.
  2. Replace the try block with a direct equality comparison (return value1 == value2).

Suggested changeset 1
onnxscript/rewriter/_basics.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxscript/rewriter/_basics.py b/onnxscript/rewriter/_basics.py
--- a/onnxscript/rewriter/_basics.py
+++ b/onnxscript/rewriter/_basics.py
@@ -49,7 +49,3 @@
     # For other types, use regular equality
-    try:
-        return value1 == value2
-    except Exception:
-        # If comparison fails, values are not equal
-        return False
+    return value1 == value2
 
EOF
@@ -49,7 +49,3 @@
# For other types, use regular equality
try:
return value1 == value2
except Exception:
# If comparison fails, values are not equal
return False
return value1 == value2

Copilot is powered by AI and may make mistakes. Always verify output.
# If comparison fails, values are not equal
return False

Check warning on line 54 in onnxscript/rewriter/_basics.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_basics.py#L54

Added line #L54 was not covered by tests

if TYPE_CHECKING:
import onnxscript.rewriter._pattern_ir as _pattern_ir
import onnxscript.rewriter._rewrite_rule as _rewrite_rule
Expand Down Expand Up @@ -141,37 +183,66 @@
self._current_match.add_node(node)

def bind_value(self, pattern_value: _pattern_ir.ValuePattern, value: Any) -> bool:
"""Bind a pattern value to an actual value.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Args:
pattern_value: The pattern value to bind
value: The actual value to bind to

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Returns:
True if binding succeeded, False if there was a conflict
"""
var_name = pattern_value.name
# TODO(rama): Simplify the following. We currently bind values to
# pattern variables in two different ways: via their name, or via the
# pattern-value itself.
if var_name is None:
for match in self._partial_matches:
if pattern_value in match.value_bindings:
# TODO(rama): Use appropriate equality-check here.
if match.value_bindings[pattern_value] == value:
return True
self._current_match.fail(
f"Binding failure: {pattern_value} bound to two different values.",
[match.value_bindings[pattern_value], value],
)
return False
self._current_match.value_bindings[pattern_value] = value
return True
return self.bind(var_name, value)
# Use the pattern value itself as the key
return self._bind_to_key(pattern_value, value, self._current_match.value_bindings)
else:
# Use the variable name as the key
return self.bind(var_name, value)

def bind(self, var: str, value: Any) -> bool:
"""Bind a variable name to a value.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Args:
var: The variable name to bind
value: The value to bind to

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Returns:
True if binding succeeded, False if there was a conflict
"""
return self._bind_to_key(var, value, self._current_match.bindings)

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

def _bind_to_key(self, key: Any, value: Any, binding_dict: dict[Any, Any]) -> bool:
"""Helper method to bind a key to a value, checking for conflicts.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Args:
key: The key to bind (variable name or pattern value)
value: The value to bind to
binding_dict: The dictionary to store the binding in

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Returns:
True if binding succeeded, False if there was a conflict
"""
# Check all partial matches for existing bindings
for match in self._partial_matches:
if var in match.bindings:
# TODO(rama): Use appropriate equality-check here.
if match.bindings[var] == value:
relevant_bindings = (
match.value_bindings if binding_dict is self._current_match.value_bindings
else match.bindings
)
if key in relevant_bindings:
existing_value = relevant_bindings[key]
if _values_equal(existing_value, value):
return True
# Binding conflict - report failure
self._current_match.fail(
f"Binding failure: {var} bound to two different values.",
[match.bindings[var], value],
f"Binding conflict: {key} already bound to {existing_value}, "
f"cannot rebind to {value}",
[existing_value, value]
)
return False
self._current_match.bindings[var] = value

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

# No existing binding found, create new binding
binding_dict[key] = value
return True

@property
Expand Down
125 changes: 111 additions & 14 deletions onnxscript/rewriter/_rewrite_rule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rewrite rules for ONNX models."""
"""Rewrite rules for ONNX models.

This module provides the core functionality for pattern-based rewriting of ONNX models.
It includes:

- RewriteRule: Defines a single pattern-to-replacement rewrite transformation
- RewriteRuleSet: Manages a collection of rewrite rules and applies them to models
- RewriteRuleClassBase: Base class for implementing rewrite rules using a class-based API
- Supporting utilities for pattern matching, replacement, and context management

The rewriter enables users to define patterns that match subgraphs in ONNX models
and replace them with equivalent but potentially more efficient implementations.

Example usage:
# Define a simple pattern and replacement
def add_zero_pattern(op, x):
return op.Add(x, op.Constant(value=0.0))

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
def identity_replacement(op, x):
return op.Identity(x)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Create and apply the rule
rule = RewriteRule(add_zero_pattern, identity_replacement)
rule.apply_to_model(model)
"""

from __future__ import annotations

Expand All @@ -26,6 +50,26 @@
RewriterContext = _tape.Builder


@dataclasses.dataclass
class _RewriteContext:
"""Context object providing information to condition functions during pattern matching.

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
This object provides access to the model, graph/function, current node, and match
information that condition functions can use to make decisions about whether to
apply a rewrite rule.

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
Attributes:
model: The IR model being processed
graph_or_function: The graph or function containing the matched node
node: The current node being matched
match: The match result containing bindings and matched nodes
"""
model: ir.Model
graph_or_function: ir.Graph | ir.Function

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

node: ir.Node
match: _basics.MatchResult


@dataclasses.dataclass
class ReplacementSubgraph:
"""A subgraph that will replace the matched pattern."""
Expand Down Expand Up @@ -83,6 +127,24 @@


class RewriteRule:
"""A pattern-based rewrite rule for transforming ONNX models.

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
A RewriteRule defines a pattern to match in an ONNX graph and a replacement
that should be substituted when the pattern is found. The rule can include
an optional condition function to further validate whether the replacement
should be applied.

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
The rewrite process involves:
1. Pattern matching: Finding subgraphs that match the target pattern
2. Condition checking: Validating that the match satisfies additional constraints
3. Replacement: Substituting the matched subgraph with the replacement pattern

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
Attributes:
name: Optional name for the rule (used in verbose output and debugging)
remove_nodes: Whether matched nodes should be removed after replacement
as_function: Whether to extract the replacement as a model-local function
"""

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
def __init__(
self,
target_pattern: _pattern_ir.GraphPattern | Callable,
Expand Down Expand Up @@ -170,7 +232,8 @@
model, graph_or_function, node, verbose=verbose, remove_nodes=self.remove_nodes
)
if match:
context = None # TODO(rama)
# Create a simple context object containing useful information for condition functions
context = _RewriteContext(model, graph_or_function, node, match)
for var in self._target_pattern.inputs:
if var.name is not None:
if var.name not in match.bindings:
Expand Down Expand Up @@ -212,7 +275,9 @@
f"Number of outputs from replacement function does not match the number of outputs from the target pattern. "
f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}."
)
# TODO(rama): Remove the opset imports from deleted nodes?
# Update opset imports for new operations introduced by the replacement
# Note: Cleaning up unused opset imports from deleted nodes is handled by
# the RemoveUnusedOpsetsPass that runs after rewriting is complete
_update_opset_imports(graph_or_function, replacement_subgraph)
_update_opset_imports(model.graph, replacement_subgraph)
if tracer:
Expand All @@ -237,15 +302,30 @@
)

def commute(self) -> Sequence[RewriteRule]:
"""Generate commutative variants of this rule.

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
Returns a list of rules that match commutative variants of the target pattern.
For example, if the pattern matches Add(x, y), commutative variants would
include Add(y, x).

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
Returns:
A sequence of RewriteRule instances for commutative pattern variants
"""
def replace_pattern(new_pattern):
"""Return a shallow copy of self with node_pattern replaced by new_pattern."""
# TODO(rama): Maybe we should use a better alternative to construct new matcher.
matcher_class = type(self._matcher)
"""Create a new rule with the given pattern, preserving other settings."""
# Use the same matcher creation logic as in __init__ for consistency
if isinstance(self._matcher, _matcher.SimplePatternMatcher):
new_matcher = _matcher.SimplePatternMatcher(new_pattern)
else:
# For more complex matchers, try to create an equivalent one
import onnxscript.rewriter.generic_pattern as generic_pattern
new_matcher = generic_pattern.GenericPatternMatcher(new_pattern)

Check warning on line 322 in onnxscript/rewriter/_rewrite_rule.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_rewrite_rule.py#L321-L322

Added lines #L321 - L322 were not covered by tests

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
return RewriteRule(
new_pattern,
self._replacement_pattern,
self._condition_function,
matcher_class(new_pattern),
new_matcher,
self._verbose,
self.name,
self.remove_nodes,
Expand Down Expand Up @@ -466,11 +546,14 @@
count = 0

# NOTE: Rules should be prioritized in the order they are added to the RewriteRuleSet.
# And the graph is applied in order.
# The graph is processed in order, but we need to be careful about modification during iteration.
for rule in self.rules:
if rule.graph_pre_visitor:
rule.graph_pre_visitor()
for node in graph_or_function:

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Convert to list to avoid issues if graph is modified during iteration
nodes_to_process = list(graph_or_function)
for node in nodes_to_process:
delta = rule.try_rewrite(
model, graph_or_function, node, verbose=verbose, tracer=tracer
)
Expand All @@ -494,10 +577,12 @@
continue
for initializer in delta.new_initializers:
initializers[initializer.name] = initializer # type: ignore[index]
# TODO: This does not yet handle the problem of determining the correct insertion point
# for inserted nodes in the case of patterns with multiple output-nodes. The following
# is sufficient for patterns with a single output-node "node", which can serve as the
# insertion-point.

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Apply basic constant propagation to newly created nodes
# Note: For patterns with multiple output nodes, the insertion point
# is determined by the convenience.replace_nodes_and_values function.
# This works correctly for most cases, but complex patterns with
# specific ordering requirements may need additional consideration.
onnxscript.optimizer.basic_constant_propagation(delta.new_nodes)
if rule.as_function:
# Create a function out of a copy of the matched nodes
Expand Down Expand Up @@ -573,18 +658,30 @@
The number of applications of rewrite rules.
"""
assert isinstance(model, ir.Model)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Apply initial constant propagation once at the start
onnxscript.optimizer.basic_constant_propagation(model.graph)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Rewriting may introduce new functions. In the following loop,
# we restrict rewriting to original functions, not newly introduced ones.
original_functions = list(model.functions.values())

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Apply constant propagation to original functions before rewriting
for function in original_functions:
onnxscript.optimizer.basic_constant_propagation(function)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Apply rewrite rules to main graph
count = self._apply_to_graph_or_function(
model, model.graph, verbose=verbose, tracer=tracer
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Apply rewrite rules to original functions
for function in original_functions:
onnxscript.optimizer.basic_constant_propagation(function)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this unnecessary here?

count += self._apply_to_graph_or_function(
model, function, verbose=verbose, tracer=tracer
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Final cleanup if needed
if self.remove_unused_nodes:
onnxscript.optimizer.remove_unused_nodes(model)
return count
Expand Down
Loading
Loading