-
Notifications
You must be signed in to change notification settings - Fork 67
Clean up rewriter code: improve efficiency, finish TODOs, and enhance documentation #2392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,8 +25,11 @@ | |
) | ||
|
||
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) | ||
|
||
# Default rewrite rules applied by the rewriter | ||
# These rules implement common optimizations and transformations | ||
_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( | ||
*no_op.rules.rules, # TODO: merge this rule into constant folding? | ||
*no_op.rules.rules, # Remove no-op operations (e.g., Add with 0, Mul by 1) | ||
*broadcast_to_matmul.rules.rules, | ||
gemm_to_matmul_add.rule, # type: ignore[has-type] | ||
*cast_constant_of_shape.rules.rules, | ||
|
@@ -36,6 +39,20 @@ | |
|
||
|
||
class RewritePass(ir.passes.InPlacePass): | ||
"""A pass that applies pattern-based rewrite rules to an IR model. | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
This pass takes a collection of rewrite rules and applies them to the model, | ||
transforming matching patterns according to the rule definitions. The pass | ||
operates in-place, modifying the provided model directly. | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
Args: | ||
rules: A sequence of RewriteRule objects or a RewriteRuleSet containing | ||
the rules to apply during rewriting. | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
Raises: | ||
ValueError: If the rules sequence is empty. | ||
""" | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
def __init__( | ||
self, | ||
rules: Sequence[pattern.RewriteRule] | pattern.RewriteRuleSet, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,21 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""Basic types for the pattern matching and rewriter API.""" | ||
"""Basic types for the pattern matching and rewriter API. | ||
|
||
This module contains fundamental data structures and utilities used throughout | ||
the rewriter system: | ||
|
||
- MatchResult: Tracks the state of pattern matching operations | ||
- MatchFailureInfo/MatchFailureError: Handle match failure scenarios | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W291 Warning
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
- PartialMatchResult: Internal state for managing backtracking during OR patterns | ||
- Utility functions for value comparison and binding management | ||
|
||
The matching system supports advanced features like: | ||
- OR patterns with backtracking | ||
- Robust value binding with conflict detection | ||
- Detailed failure reporting for debugging | ||
- Support for both named and anonymous pattern variables | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
|
@@ -11,6 +26,33 @@ | |
|
||
from onnxscript import ir | ||
|
||
def _values_equal(value1: Any, value2: Any) -> bool: | ||
"""Check if two values are equal for binding purposes. | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
This function provides a more robust equality check than direct comparison, | ||
handling special cases for IR values and nodes. | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
Args: | ||
value1: First value to compare | ||
value2: Second value to compare | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
Returns: | ||
True if the values are considered equal for binding purposes | ||
""" | ||
if value1 is value2: | ||
return True | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
# For IR values and nodes, use identity comparison | ||
if isinstance(value1, (ir.Value, ir.Node)) or isinstance(value2, (ir.Value, ir.Node)): | ||
return value1 is value2 | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
# For other types, use regular equality | ||
try: | ||
return value1 == value2 | ||
except Exception: | ||
# If comparison fails, values are not equal | ||
return False | ||
|
||
if TYPE_CHECKING: | ||
import onnxscript.rewriter._pattern_ir as _pattern_ir | ||
import onnxscript.rewriter._rewrite_rule as _rewrite_rule | ||
|
@@ -141,37 +183,66 @@ | |
self._current_match.add_node(node) | ||
|
||
def bind_value(self, pattern_value: _pattern_ir.ValuePattern, value: Any) -> bool: | ||
"""Bind a pattern value to an actual value. | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
Args: | ||
pattern_value: The pattern value to bind | ||
value: The actual value to bind to | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
Returns: | ||
True if binding succeeded, False if there was a conflict | ||
""" | ||
var_name = pattern_value.name | ||
# TODO(rama): Simplify the following. We currently bind values to | ||
# pattern variables in two different ways: via their name, or via the | ||
# pattern-value itself. | ||
if var_name is None: | ||
for match in self._partial_matches: | ||
if pattern_value in match.value_bindings: | ||
# TODO(rama): Use appropriate equality-check here. | ||
if match.value_bindings[pattern_value] == value: | ||
return True | ||
self._current_match.fail( | ||
f"Binding failure: {pattern_value} bound to two different values.", | ||
[match.value_bindings[pattern_value], value], | ||
) | ||
return False | ||
self._current_match.value_bindings[pattern_value] = value | ||
return True | ||
return self.bind(var_name, value) | ||
# Use the pattern value itself as the key | ||
return self._bind_to_key(pattern_value, value, self._current_match.value_bindings) | ||
else: | ||
# Use the variable name as the key | ||
return self.bind(var_name, value) | ||
|
||
def bind(self, var: str, value: Any) -> bool: | ||
"""Bind a variable name to a value. | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
Args: | ||
var: The variable name to bind | ||
value: The value to bind to | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
Returns: | ||
True if binding succeeded, False if there was a conflict | ||
""" | ||
return self._bind_to_key(var, value, self._current_match.bindings) | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
def _bind_to_key(self, key: Any, value: Any, binding_dict: dict[Any, Any]) -> bool: | ||
"""Helper method to bind a key to a value, checking for conflicts. | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
Args: | ||
key: The key to bind (variable name or pattern value) | ||
value: The value to bind to | ||
binding_dict: The dictionary to store the binding in | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
Returns: | ||
True if binding succeeded, False if there was a conflict | ||
""" | ||
# Check all partial matches for existing bindings | ||
for match in self._partial_matches: | ||
if var in match.bindings: | ||
# TODO(rama): Use appropriate equality-check here. | ||
if match.bindings[var] == value: | ||
relevant_bindings = ( | ||
match.value_bindings if binding_dict is self._current_match.value_bindings | ||
else match.bindings | ||
) | ||
if key in relevant_bindings: | ||
existing_value = relevant_bindings[key] | ||
if _values_equal(existing_value, value): | ||
return True | ||
# Binding conflict - report failure | ||
self._current_match.fail( | ||
f"Binding failure: {var} bound to two different values.", | ||
[match.bindings[var], value], | ||
f"Binding conflict: {key} already bound to {existing_value}, " | ||
f"cannot rebind to {value}", | ||
[existing_value, value] | ||
) | ||
return False | ||
self._current_match.bindings[var] = value | ||
|
||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
# No existing binding found, create new binding | ||
binding_dict[key] = value | ||
return True | ||
|
||
@property | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,30 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""Rewrite rules for ONNX models.""" | ||
"""Rewrite rules for ONNX models. | ||
|
||
This module provides the core functionality for pattern-based rewriting of ONNX models. | ||
It includes: | ||
|
||
- RewriteRule: Defines a single pattern-to-replacement rewrite transformation | ||
- RewriteRuleSet: Manages a collection of rewrite rules and applies them to models | ||
- RewriteRuleClassBase: Base class for implementing rewrite rules using a class-based API | ||
- Supporting utilities for pattern matching, replacement, and context management | ||
|
||
The rewriter enables users to define patterns that match subgraphs in ONNX models | ||
and replace them with equivalent but potentially more efficient implementations. | ||
|
||
Example usage: | ||
# Define a simple pattern and replacement | ||
def add_zero_pattern(op, x): | ||
return op.Add(x, op.Constant(value=0.0)) | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
def identity_replacement(op, x): | ||
return op.Identity(x) | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
# Create and apply the rule | ||
rule = RewriteRule(add_zero_pattern, identity_replacement) | ||
rule.apply_to_model(model) | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
|
@@ -26,6 +50,26 @@ | |
RewriterContext = _tape.Builder | ||
|
||
|
||
@dataclasses.dataclass | ||
class _RewriteContext: | ||
"""Context object providing information to condition functions during pattern matching. | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
This object provides access to the model, graph/function, current node, and match | ||
information that condition functions can use to make decisions about whether to | ||
apply a rewrite rule. | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
Attributes: | ||
model: The IR model being processed | ||
graph_or_function: The graph or function containing the matched node | ||
node: The current node being matched | ||
match: The match result containing bindings and matched nodes | ||
""" | ||
model: ir.Model | ||
graph_or_function: ir.Graph | ir.Function | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W291 Warning
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
node: ir.Node | ||
match: _basics.MatchResult | ||
|
||
|
||
@dataclasses.dataclass | ||
class ReplacementSubgraph: | ||
"""A subgraph that will replace the matched pattern.""" | ||
|
@@ -83,6 +127,24 @@ | |
|
||
|
||
class RewriteRule: | ||
"""A pattern-based rewrite rule for transforming ONNX models. | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
A RewriteRule defines a pattern to match in an ONNX graph and a replacement | ||
that should be substituted when the pattern is found. The rule can include | ||
an optional condition function to further validate whether the replacement | ||
should be applied. | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
The rewrite process involves: | ||
1. Pattern matching: Finding subgraphs that match the target pattern | ||
2. Condition checking: Validating that the match satisfies additional constraints | ||
3. Replacement: Substituting the matched subgraph with the replacement pattern | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
Attributes: | ||
name: Optional name for the rule (used in verbose output and debugging) | ||
remove_nodes: Whether matched nodes should be removed after replacement | ||
as_function: Whether to extract the replacement as a model-local function | ||
""" | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
def __init__( | ||
self, | ||
target_pattern: _pattern_ir.GraphPattern | Callable, | ||
|
@@ -170,7 +232,8 @@ | |
model, graph_or_function, node, verbose=verbose, remove_nodes=self.remove_nodes | ||
) | ||
if match: | ||
context = None # TODO(rama) | ||
# Create a simple context object containing useful information for condition functions | ||
context = _RewriteContext(model, graph_or_function, node, match) | ||
for var in self._target_pattern.inputs: | ||
if var.name is not None: | ||
if var.name not in match.bindings: | ||
|
@@ -212,7 +275,9 @@ | |
f"Number of outputs from replacement function does not match the number of outputs from the target pattern. " | ||
f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}." | ||
) | ||
# TODO(rama): Remove the opset imports from deleted nodes? | ||
# Update opset imports for new operations introduced by the replacement | ||
# Note: Cleaning up unused opset imports from deleted nodes is handled by | ||
# the RemoveUnusedOpsetsPass that runs after rewriting is complete | ||
_update_opset_imports(graph_or_function, replacement_subgraph) | ||
_update_opset_imports(model.graph, replacement_subgraph) | ||
if tracer: | ||
|
@@ -237,15 +302,30 @@ | |
) | ||
|
||
def commute(self) -> Sequence[RewriteRule]: | ||
"""Generate commutative variants of this rule. | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
Returns a list of rules that match commutative variants of the target pattern. | ||
For example, if the pattern matches Add(x, y), commutative variants would | ||
include Add(y, x). | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
Returns: | ||
A sequence of RewriteRule instances for commutative pattern variants | ||
""" | ||
def replace_pattern(new_pattern): | ||
"""Return a shallow copy of self with node_pattern replaced by new_pattern.""" | ||
# TODO(rama): Maybe we should use a better alternative to construct new matcher. | ||
matcher_class = type(self._matcher) | ||
"""Create a new rule with the given pattern, preserving other settings.""" | ||
# Use the same matcher creation logic as in __init__ for consistency | ||
if isinstance(self._matcher, _matcher.SimplePatternMatcher): | ||
new_matcher = _matcher.SimplePatternMatcher(new_pattern) | ||
else: | ||
# For more complex matchers, try to create an equivalent one | ||
import onnxscript.rewriter.generic_pattern as generic_pattern | ||
new_matcher = generic_pattern.GenericPatternMatcher(new_pattern) | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
return RewriteRule( | ||
new_pattern, | ||
self._replacement_pattern, | ||
self._condition_function, | ||
matcher_class(new_pattern), | ||
new_matcher, | ||
self._verbose, | ||
self.name, | ||
self.remove_nodes, | ||
|
@@ -466,11 +546,14 @@ | |
count = 0 | ||
|
||
# NOTE: Rules should be prioritized in the order they are added to the RewriteRuleSet. | ||
# And the graph is applied in order. | ||
# The graph is processed in order, but we need to be careful about modification during iteration. | ||
for rule in self.rules: | ||
if rule.graph_pre_visitor: | ||
rule.graph_pre_visitor() | ||
for node in graph_or_function: | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
# Convert to list to avoid issues if graph is modified during iteration | ||
nodes_to_process = list(graph_or_function) | ||
for node in nodes_to_process: | ||
delta = rule.try_rewrite( | ||
model, graph_or_function, node, verbose=verbose, tracer=tracer | ||
) | ||
|
@@ -494,10 +577,12 @@ | |
continue | ||
for initializer in delta.new_initializers: | ||
initializers[initializer.name] = initializer # type: ignore[index] | ||
# TODO: This does not yet handle the problem of determining the correct insertion point | ||
# for inserted nodes in the case of patterns with multiple output-nodes. The following | ||
# is sufficient for patterns with a single output-node "node", which can serve as the | ||
# insertion-point. | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
# Apply basic constant propagation to newly created nodes | ||
# Note: For patterns with multiple output nodes, the insertion point | ||
# is determined by the convenience.replace_nodes_and_values function. | ||
# This works correctly for most cases, but complex patterns with | ||
# specific ordering requirements may need additional consideration. | ||
onnxscript.optimizer.basic_constant_propagation(delta.new_nodes) | ||
if rule.as_function: | ||
# Create a function out of a copy of the matched nodes | ||
|
@@ -573,18 +658,30 @@ | |
The number of applications of rewrite rules. | ||
""" | ||
assert isinstance(model, ir.Model) | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
# Apply initial constant propagation once at the start | ||
onnxscript.optimizer.basic_constant_propagation(model.graph) | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
# Rewriting may introduce new functions. In the following loop, | ||
# we restrict rewriting to original functions, not newly introduced ones. | ||
original_functions = list(model.functions.values()) | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
# Apply constant propagation to original functions before rewriting | ||
for function in original_functions: | ||
onnxscript.optimizer.basic_constant_propagation(function) | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
# Apply rewrite rules to main graph | ||
count = self._apply_to_graph_or_function( | ||
model, model.graph, verbose=verbose, tracer=tracer | ||
) | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
# Apply rewrite rules to original functions | ||
for function in original_functions: | ||
onnxscript.optimizer.basic_constant_propagation(function) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this unnecessary here? |
||
count += self._apply_to_graph_or_function( | ||
model, function, verbose=verbose, tracer=tracer | ||
) | ||
|
||
Check warningCode scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
|
||
# Final cleanup if needed | ||
if self.remove_unused_nodes: | ||
onnxscript.optimizer.remove_unused_nodes(model) | ||
return count | ||
|
Check warning
Code scanning / CodeQL
Unreachable code Warning
Copilot Autofix
AI 5 days ago
To fix the issue, the unreachable
except Exception:
block should be removed. This simplifies the code and makes it clearer, while preserving its intended functionality. The equality comparison (value1 == value2
) is robust enough for standard use cases, and the removal of the unreachable code does not affect the behavior of the function.Changes to make:
try
block and theexcept Exception:
block entirely.try
block with a direct equality comparison (return value1 == value2
).