Skip to content

[docs] Document rewriter pattern options #2406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/tutorial/rewriter/allow_other_inputs.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Specifying variable inputs in the pattern

This section demonstrates the use of the `_allow_other_inputs` option in pattern-based rewriting.
The `_allow_other_inputs` option allows the pattern to match nodes that have additional inputs
beyond those specified in the pattern. If it is set to `False` (the default), then the node must
have exactly the specified inputs for a successful match. If set to `True`, the pattern will
match nodes that have the specified inputs plus any number of additional inputs.

This is particularly useful when matching operations like `Conv` that can have optional inputs
(such as bias), or when creating generic patterns that should work with various input configurations.

```{literalinclude} examples/allow_other_inputs.py
:pyobject: conv_pattern
```

```{literalinclude} examples/allow_other_inputs.py
:pyobject: conv_replacement
```

```{literalinclude} examples/allow_other_inputs.py
:pyobject: apply_rewrite
```

In this example, the pattern matches `Conv` operations with any number of inputs. A `Conv` operation
might have 2 inputs (input and weight) or 3 inputs (input, weight, and bias). By setting
`_allow_other_inputs=True`, our pattern will match both cases even though we only specify 2 inputs
in the pattern definition.
1 change: 1 addition & 0 deletions docs/tutorial/rewriter/attributes.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ This section demonstrates the use of attribute values in pattern-based rewriting
First, write a target pattern and replacement pattern in a similar way to the previous examples.
The example pattern below will match successfully only against Dropout nodes with the
attribute value `training_mode` set to `False`.

The `_allow_other_attributes` option allows the pattern to match nodes that have additional attributes
not specified in the pattern. If it is set to `False`, then the node must have only the specified
attribute values, and no other attributes, for a successful match. The default value for this
Expand Down
38 changes: 38 additions & 0 deletions docs/tutorial/rewriter/domain_option.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Specifying domains in the pattern

This section demonstrates the use of the `_domain` option in pattern-based rewriting.
The `_domain` option allows you to specify which operator domain the pattern should match against,
and also allows you to create replacement operations in specific domains.

ONNX operators can belong to different domains:
- The default ONNX domain (empty string or "ai.onnx")
- Custom domains like "com.microsoft" for Microsoft-specific operations
- User-defined domains for custom operations

## Matching operations from a specific domain

```{literalinclude} examples/domain_option.py
:pyobject: custom_relu_pattern
```

In this pattern, `_domain="custom.domain"` ensures that only `Relu` operations from the
"custom.domain" domain will be matched, not standard ONNX `Relu` operations.

## Creating replacement operations in a specific domain

```{literalinclude} examples/domain_option.py
:pyobject: microsoft_relu_replacement
```

Here, the replacement operation is created in the "com.microsoft" domain, which might
provide optimized implementations of standard operations.

## Complete rewrite example

```{literalinclude} examples/domain_option.py
:pyobject: apply_rewrite
```

This example shows how domain-specific pattern matching can be used to migrate operations
between different operator domains, such as replacing custom domain operations with
standard ONNX operations or vice versa.
71 changes: 71 additions & 0 deletions docs/tutorial/rewriter/examples/allow_other_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""ONNX Pattern Rewriting with variable number of inputs

This script shows how to define a rewriting rule based on patterns that
can match nodes with additional inputs beyond those specified in the pattern.
"""

import onnx

import onnxscript
from onnxscript import FLOAT, opset18, script
from onnxscript.rewriter import pattern


@script()
def original_model(A: FLOAT[2, 2], B: FLOAT[2, 2], C: FLOAT[2, 2]) -> FLOAT[2, 2]:
# Conv with bias - has 3 inputs: input, weight, bias
result = opset18.Conv(A, B, C)
return result


_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
# The target pattern
# =====================


def conv_pattern(op, input, weight):
# Pattern to match Conv operations, allowing additional inputs like bias
# _allow_other_inputs=True allows the pattern to match Conv with bias (3 inputs)
# even though we only specify 2 inputs in the pattern
return op.Conv(input, weight, _allow_other_inputs=True)


####################################
# The replacement pattern
# =====================


def conv_replacement(op, input, weight, **_):
# Replace with a custom operation in a different domain
return op.OptimizedConv(input, weight, _domain="custom.domain")


####################################
# Create Rewrite Rule and Apply to Model
# =====================


def apply_rewrite(model):
# Create rewrite rules
conv_rule = pattern.RewriteRule(
conv_pattern, # target pattern
conv_replacement, # replacement pattern
)
# Create a Rewrite Rule Set
rewrite_rule_set = pattern.RewriteRuleSet([conv_rule])
# Apply rewrite
model_with_rewrite = onnxscript.rewriter.rewrite(
model,
pattern_rewrite_rules=rewrite_rule_set,
)
return model_with_rewrite


_model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(_model_with_rewrite)
86 changes: 86 additions & 0 deletions docs/tutorial/rewriter/examples/domain_option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""ONNX Pattern Rewriting with domain specification

This script shows how to define a rewriting rule that targets operations
from specific domains and replaces them with operations in other domains.
"""

import onnx

import onnxscript
from onnxscript import script
from onnxscript.rewriter import pattern
from onnxscript.values import Opset

# Create an opset for the custom domain
opset = Opset("custom.domain", 1)


@script(opset)
def create_model_with_custom_domain(input: onnxscript.FLOAT[2, 2]) -> onnxscript.FLOAT[2, 2]:
"""Create a model with a Relu operation in a custom domain."""
return opset.Relu(input)


_model = create_model_with_custom_domain.to_model_proto()
_model = onnx.shape_inference.infer_shapes(_model)
onnx.checker.check_model(_model)


####################################
# The target pattern
# =====================


def custom_relu_pattern(op, input):
# Pattern to match Relu operations from a specific domain
# _domain="custom.domain" specifies we only want to match operations from this domain
return op.Relu(input, _domain="custom.domain")


####################################
# The replacement pattern
# =====================


def standard_relu_replacement(op, input, **_):
# Replace with standard ONNX Relu (default domain)
return op.Relu(input)


####################################
# Alternative: Replace with operation in different domain
# =====================


def microsoft_relu_replacement(op, input, **_):
# Replace with operation in Microsoft's domain
return op.OptimizedRelu(input, _domain="com.microsoft")


####################################
# Create Rewrite Rule and Apply to Model
# =====================


def apply_rewrite(model):
# Create rewrite rules
relu_rule = pattern.RewriteRule(
custom_relu_pattern, # target pattern - matches custom domain operations
standard_relu_replacement, # replacement pattern - uses standard domain
)
# Create a Rewrite Rule Set
rewrite_rule_set = pattern.RewriteRuleSet([relu_rule])
# Apply rewrite
model_with_rewrite = onnxscript.rewriter.rewrite(
model,
pattern_rewrite_rules=rewrite_rule_set,
)
return model_with_rewrite


# The rewrite rule will now match the Relu operation in the custom domain
# and replace it with a standard ONNX Relu operation
_model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(_model_with_rewrite)
75 changes: 75 additions & 0 deletions docs/tutorial/rewriter/examples/outputs_option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""ONNX Pattern Rewriting with output specification

This script shows how to define a rewriting rule that specifies
the number and names of outputs from operations.
"""

import onnx

import onnxscript
from onnxscript import FLOAT, opset18, script
from onnxscript.rewriter import pattern


@script()
def original_model(A: FLOAT[4, 4]) -> FLOAT[2, 4]:
# Split operation that produces 2 outputs
result1, result2 = opset18.Split(A, num_outputs=2, axis=0)

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning documentation

Unused variable 'result2' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable

Check warning

Code scanning / lintrunner

RUFF/RUF059 Warning documentation

Unpacked variable result2 is never used.
See https://docs.astral.sh/ruff/rules/unused-unpacked-variable
# We only return the first output for simplicity
return result1


_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
# The target pattern with multiple outputs
# =====================


def split_pattern(op, input):
# Pattern to match Split operations with 2 outputs
# _outputs=2 specifies that this operation produces 2 outputs
return op.Split(input, num_outputs=2, axis=0, _outputs=2)


####################################
# The replacement pattern with named outputs
# =====================


def custom_split_replacement(op, input, **_):
# Replace with a custom split operation using named outputs
# _outputs=["first_half", "second_half"] assigns names to the outputs
# IMPORTANT: The number of outputs must match the pattern (2 outputs)
return op.CustomSplit(
input, _domain="custom.domain", _outputs=["first_half", "second_half"]
)


####################################
# Create Rewrite Rule and Apply to Model
# =====================


def apply_rewrite(model):
# Create rewrite rules
split_rule = pattern.RewriteRule(
split_pattern, # target pattern - matches Split with 2 outputs
custom_split_replacement, # replacement pattern - uses named outputs
)
# Create a Rewrite Rule Set
rewrite_rule_set = pattern.RewriteRuleSet([split_rule])
# Apply rewrite
model_with_rewrite = onnxscript.rewriter.rewrite(
model,
pattern_rewrite_rules=rewrite_rule_set,
)
return model_with_rewrite


_model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(_model_with_rewrite)
43 changes: 43 additions & 0 deletions docs/tutorial/rewriter/outputs_option.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Specifying outputs in the pattern

This section demonstrates the use of the `_outputs` option in pattern-based rewriting.
The `_outputs` option allows you to specify the number of outputs an operation produces
and optionally assign names to those outputs for easier reference in replacement patterns.

The `_outputs` option can be specified in two ways:
- As an integer: `_outputs=2` specifies that the operation produces 2 unnamed outputs
- As a list of strings/None: `_outputs=["first", "second"]` specifies 2 named outputs

## Matching operations with multiple outputs

```{literalinclude} examples/outputs_option.py
:pyobject: split_pattern
```

This pattern matches `Split` operations that produce exactly 2 outputs. The `_outputs=2`
specification ensures the pattern only matches operations with this specific output count.

## Creating replacement operations with named outputs

```{literalinclude} examples/outputs_option.py
:pyobject: custom_split_replacement
```

In the replacement, `_outputs=["first_half", "second_half"]` creates two outputs with
descriptive names. This can make the replacement pattern more readable and maintainable.

**Important**: The number of outputs in the replacement pattern must match the number of
outputs in the target pattern. Since the pattern specifies `_outputs=2`, the replacement
must also produce exactly 2 outputs.

## Complete rewrite example

```{literalinclude} examples/outputs_option.py
:pyobject: apply_rewrite
```

The `_outputs` option is particularly important when:
- Working with operations that have variable numbers of outputs (like `Split`)
- Creating custom operations that need specific output configurations
- Ensuring pattern matching precision by specifying exact output counts
- Improving code readability by naming outputs in replacement patterns
20 changes: 20 additions & 0 deletions docs/tutorial/rewriter/rewrite_patterns.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,32 @@ There are three main components needed when rewriting patterns in the graph:
2. `replacement_pattern` : Pattern to replace the original pattern with. This pattern is also written as a function using ONNXScript-like operators.
3. `match_condition` (optional) : Pattern rewrite will occur only if the match condition is satisfied.

## Pattern Options

When defining patterns, you can use several special options to control how patterns match and what they produce:

- `_allow_other_attributes`: Controls whether the pattern allows additional attributes not specified in the pattern (default: True)
- `_allow_other_inputs`: Controls whether the pattern allows additional inputs beyond those specified (default: False)
- `_domain`: Specifies the operator domain for matching or creating operations
- `_outputs`: Specifies the number and optionally names of outputs from an operation

These options are documented in detail in the following sections.

```{include} simple_example.md
```

```{include} attributes.md
```

```{include} allow_other_inputs.md
```

```{include} domain_option.md
```

```{include} outputs_option.md
```

```{include} conditional_rewrite.md
```

Expand Down
Loading