-
Notifications
You must be signed in to change notification settings - Fork 71
[docs] Document rewriter pattern options #2406
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Copilot
wants to merge
8
commits into
main
Choose a base branch
from
copilot/fix-2405
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
99cef5c
Initial plan for issue
Copilot 40ea7ac
Add documentation for rewriter pattern options
Copilot d082599
Fix domain_option.py example to use model with custom domain operations
Copilot 9c2d145
Remove incorrect single and triple output replacements from outputs_o…
Copilot aaceba9
Use onnxscript.script with custom Opset in domain_option.py example
Copilot 12f0fea
Fix whitespace issues in documentation files
Copilot 2f72d62
Merge branch 'main' into copilot/fix-2405
gramalingam 2a6889e
Fix linting issues identified by lintrunner
Copilot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Specifying variable inputs in the pattern | ||
|
||
This section demonstrates the use of the `_allow_other_inputs` option in pattern-based rewriting. | ||
The `_allow_other_inputs` option allows the pattern to match nodes that have additional inputs | ||
beyond those specified in the pattern. If it is set to `False` (the default), then the node must | ||
have exactly the specified inputs for a successful match. If set to `True`, the pattern will | ||
match nodes that have the specified inputs plus any number of additional inputs. | ||
|
||
This is particularly useful when matching operations like `Conv` that can have optional inputs | ||
(such as bias), or when creating generic patterns that should work with various input configurations. | ||
|
||
```{literalinclude} examples/allow_other_inputs.py | ||
:pyobject: conv_pattern | ||
``` | ||
|
||
```{literalinclude} examples/allow_other_inputs.py | ||
:pyobject: conv_replacement | ||
``` | ||
|
||
```{literalinclude} examples/allow_other_inputs.py | ||
:pyobject: apply_rewrite | ||
``` | ||
|
||
In this example, the pattern matches `Conv` operations with any number of inputs. A `Conv` operation | ||
might have 2 inputs (input and weight) or 3 inputs (input, weight, and bias). By setting | ||
`_allow_other_inputs=True`, our pattern will match both cases even though we only specify 2 inputs | ||
in the pattern definition. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Specifying domains in the pattern | ||
|
||
This section demonstrates the use of the `_domain` option in pattern-based rewriting. | ||
The `_domain` option allows you to specify which operator domain the pattern should match against, | ||
and also allows you to create replacement operations in specific domains. | ||
|
||
ONNX operators can belong to different domains: | ||
- The default ONNX domain (empty string or "ai.onnx") | ||
- Custom domains like "com.microsoft" for Microsoft-specific operations | ||
- User-defined domains for custom operations | ||
|
||
## Matching operations from a specific domain | ||
|
||
```{literalinclude} examples/domain_option.py | ||
:pyobject: custom_relu_pattern | ||
``` | ||
|
||
In this pattern, `_domain="custom.domain"` ensures that only `Relu` operations from the | ||
"custom.domain" domain will be matched, not standard ONNX `Relu` operations. | ||
|
||
## Creating replacement operations in a specific domain | ||
|
||
```{literalinclude} examples/domain_option.py | ||
:pyobject: microsoft_relu_replacement | ||
``` | ||
|
||
Here, the replacement operation is created in the "com.microsoft" domain, which might | ||
provide optimized implementations of standard operations. | ||
|
||
## Complete rewrite example | ||
|
||
```{literalinclude} examples/domain_option.py | ||
:pyobject: apply_rewrite | ||
``` | ||
|
||
This example shows how domain-specific pattern matching can be used to migrate operations | ||
between different operator domains, such as replacing custom domain operations with | ||
standard ONNX operations or vice versa. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""ONNX Pattern Rewriting with variable number of inputs | ||
|
||
This script shows how to define a rewriting rule based on patterns that | ||
can match nodes with additional inputs beyond those specified in the pattern. | ||
""" | ||
|
||
import onnx | ||
|
||
import onnxscript | ||
from onnxscript import FLOAT, opset18, script | ||
from onnxscript.rewriter import pattern | ||
|
||
|
||
@script() | ||
def original_model(A: FLOAT[2, 2], B: FLOAT[2, 2], C: FLOAT[2, 2]) -> FLOAT[2, 2]: | ||
# Conv with bias - has 3 inputs: input, weight, bias | ||
result = opset18.Conv(A, B, C) | ||
return result | ||
|
||
|
||
_model = original_model.to_model_proto() | ||
onnx.checker.check_model(_model) | ||
|
||
|
||
#################################### | ||
# The target pattern | ||
# ===================== | ||
|
||
|
||
def conv_pattern(op, input, weight): | ||
# Pattern to match Conv operations, allowing additional inputs like bias | ||
# _allow_other_inputs=True allows the pattern to match Conv with bias (3 inputs) | ||
# even though we only specify 2 inputs in the pattern | ||
return op.Conv(input, weight, _allow_other_inputs=True) | ||
|
||
|
||
#################################### | ||
# The replacement pattern | ||
# ===================== | ||
|
||
|
||
def conv_replacement(op, input, weight, **_): | ||
# Replace with a custom operation in a different domain | ||
return op.OptimizedConv(input, weight, _domain="custom.domain") | ||
|
||
|
||
#################################### | ||
# Create Rewrite Rule and Apply to Model | ||
# ===================== | ||
|
||
|
||
def apply_rewrite(model): | ||
# Create rewrite rules | ||
conv_rule = pattern.RewriteRule( | ||
conv_pattern, # target pattern | ||
conv_replacement, # replacement pattern | ||
) | ||
# Create a Rewrite Rule Set | ||
rewrite_rule_set = pattern.RewriteRuleSet([conv_rule]) | ||
# Apply rewrite | ||
model_with_rewrite = onnxscript.rewriter.rewrite( | ||
model, | ||
pattern_rewrite_rules=rewrite_rule_set, | ||
) | ||
return model_with_rewrite | ||
|
||
|
||
_model_with_rewrite = apply_rewrite(_model) | ||
onnx.checker.check_model(_model_with_rewrite) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""ONNX Pattern Rewriting with domain specification | ||
|
||
This script shows how to define a rewriting rule that targets operations | ||
from specific domains and replaces them with operations in other domains. | ||
""" | ||
|
||
import onnx | ||
|
||
import onnxscript | ||
from onnxscript import script | ||
from onnxscript.rewriter import pattern | ||
from onnxscript.values import Opset | ||
|
||
# Create an opset for the custom domain | ||
opset = Opset("custom.domain", 1) | ||
|
||
|
||
@script(opset) | ||
def create_model_with_custom_domain(input: onnxscript.FLOAT[2, 2]) -> onnxscript.FLOAT[2, 2]: | ||
"""Create a model with a Relu operation in a custom domain.""" | ||
return opset.Relu(input) | ||
|
||
|
||
_model = create_model_with_custom_domain.to_model_proto() | ||
_model = onnx.shape_inference.infer_shapes(_model) | ||
onnx.checker.check_model(_model) | ||
|
||
|
||
#################################### | ||
# The target pattern | ||
# ===================== | ||
|
||
|
||
def custom_relu_pattern(op, input): | ||
# Pattern to match Relu operations from a specific domain | ||
# _domain="custom.domain" specifies we only want to match operations from this domain | ||
return op.Relu(input, _domain="custom.domain") | ||
|
||
|
||
#################################### | ||
# The replacement pattern | ||
# ===================== | ||
|
||
|
||
def standard_relu_replacement(op, input, **_): | ||
# Replace with standard ONNX Relu (default domain) | ||
return op.Relu(input) | ||
|
||
|
||
#################################### | ||
# Alternative: Replace with operation in different domain | ||
# ===================== | ||
|
||
|
||
def microsoft_relu_replacement(op, input, **_): | ||
# Replace with operation in Microsoft's domain | ||
return op.OptimizedRelu(input, _domain="com.microsoft") | ||
|
||
|
||
#################################### | ||
# Create Rewrite Rule and Apply to Model | ||
# ===================== | ||
|
||
|
||
def apply_rewrite(model): | ||
# Create rewrite rules | ||
relu_rule = pattern.RewriteRule( | ||
custom_relu_pattern, # target pattern - matches custom domain operations | ||
standard_relu_replacement, # replacement pattern - uses standard domain | ||
) | ||
# Create a Rewrite Rule Set | ||
rewrite_rule_set = pattern.RewriteRuleSet([relu_rule]) | ||
# Apply rewrite | ||
model_with_rewrite = onnxscript.rewriter.rewrite( | ||
model, | ||
pattern_rewrite_rules=rewrite_rule_set, | ||
) | ||
return model_with_rewrite | ||
|
||
|
||
# The rewrite rule will now match the Relu operation in the custom domain | ||
# and replace it with a standard ONNX Relu operation | ||
_model_with_rewrite = apply_rewrite(_model) | ||
onnx.checker.check_model(_model_with_rewrite) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""ONNX Pattern Rewriting with output specification | ||
|
||
This script shows how to define a rewriting rule that specifies | ||
the number and names of outputs from operations. | ||
""" | ||
|
||
import onnx | ||
|
||
import onnxscript | ||
from onnxscript import FLOAT, opset18, script | ||
from onnxscript.rewriter import pattern | ||
|
||
|
||
@script() | ||
def original_model(A: FLOAT[4, 4]) -> FLOAT[2, 4]: | ||
# Split operation that produces 2 outputs | ||
result1, result2 = opset18.Split(A, num_outputs=2, axis=0) | ||
Check warningCode scanning / lintrunner PYLINT/W0612 Warning documentation
Unused variable 'result2' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable Check warningCode scanning / lintrunner RUFF/RUF059 Warning documentation
Unpacked variable result2 is never used.
See https://docs.astral.sh/ruff/rules/unused-unpacked-variable |
||
# We only return the first output for simplicity | ||
return result1 | ||
|
||
|
||
_model = original_model.to_model_proto() | ||
onnx.checker.check_model(_model) | ||
|
||
|
||
#################################### | ||
# The target pattern with multiple outputs | ||
# ===================== | ||
|
||
|
||
def split_pattern(op, input): | ||
# Pattern to match Split operations with 2 outputs | ||
# _outputs=2 specifies that this operation produces 2 outputs | ||
return op.Split(input, num_outputs=2, axis=0, _outputs=2) | ||
|
||
|
||
#################################### | ||
# The replacement pattern with named outputs | ||
# ===================== | ||
|
||
|
||
def custom_split_replacement(op, input, **_): | ||
# Replace with a custom split operation using named outputs | ||
# _outputs=["first_half", "second_half"] assigns names to the outputs | ||
# IMPORTANT: The number of outputs must match the pattern (2 outputs) | ||
return op.CustomSplit( | ||
input, _domain="custom.domain", _outputs=["first_half", "second_half"] | ||
) | ||
|
||
|
||
#################################### | ||
# Create Rewrite Rule and Apply to Model | ||
# ===================== | ||
|
||
|
||
def apply_rewrite(model): | ||
# Create rewrite rules | ||
split_rule = pattern.RewriteRule( | ||
split_pattern, # target pattern - matches Split with 2 outputs | ||
custom_split_replacement, # replacement pattern - uses named outputs | ||
) | ||
# Create a Rewrite Rule Set | ||
rewrite_rule_set = pattern.RewriteRuleSet([split_rule]) | ||
# Apply rewrite | ||
model_with_rewrite = onnxscript.rewriter.rewrite( | ||
model, | ||
pattern_rewrite_rules=rewrite_rule_set, | ||
) | ||
return model_with_rewrite | ||
|
||
|
||
_model_with_rewrite = apply_rewrite(_model) | ||
onnx.checker.check_model(_model_with_rewrite) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Specifying outputs in the pattern | ||
|
||
This section demonstrates the use of the `_outputs` option in pattern-based rewriting. | ||
The `_outputs` option allows you to specify the number of outputs an operation produces | ||
and optionally assign names to those outputs for easier reference in replacement patterns. | ||
|
||
The `_outputs` option can be specified in two ways: | ||
- As an integer: `_outputs=2` specifies that the operation produces 2 unnamed outputs | ||
- As a list of strings/None: `_outputs=["first", "second"]` specifies 2 named outputs | ||
|
||
## Matching operations with multiple outputs | ||
|
||
```{literalinclude} examples/outputs_option.py | ||
:pyobject: split_pattern | ||
``` | ||
|
||
This pattern matches `Split` operations that produce exactly 2 outputs. The `_outputs=2` | ||
specification ensures the pattern only matches operations with this specific output count. | ||
|
||
## Creating replacement operations with named outputs | ||
|
||
```{literalinclude} examples/outputs_option.py | ||
:pyobject: custom_split_replacement | ||
``` | ||
|
||
In the replacement, `_outputs=["first_half", "second_half"]` creates two outputs with | ||
descriptive names. This can make the replacement pattern more readable and maintainable. | ||
|
||
**Important**: The number of outputs in the replacement pattern must match the number of | ||
outputs in the target pattern. Since the pattern specifies `_outputs=2`, the replacement | ||
must also produce exactly 2 outputs. | ||
|
||
## Complete rewrite example | ||
|
||
```{literalinclude} examples/outputs_option.py | ||
:pyobject: apply_rewrite | ||
``` | ||
|
||
The `_outputs` option is particularly important when: | ||
- Working with operations that have variable numbers of outputs (like `Split`) | ||
- Creating custom operations that need specific output configurations | ||
- Ensuring pattern matching precision by specifying exact output counts | ||
- Improving code readability by naming outputs in replacement patterns |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.