[Rewriter]: fuse successive Relu/Clip nodes #2410

Merged: 9 commits, Jul 1, 2025
noxfile.py: 2 changes (1 addition & 1 deletion)
@@ -42,7 +42,7 @@
"packaging",
"protobuf",
)
ONNX_IR = "onnx_ir==0.1.1"
ONNX_IR = "onnx_ir==0.1.3"
ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"


onnxscript/rewriter/__init__.py: 2 changes (2 additions & 0 deletions)
@@ -20,6 +20,7 @@
    broadcast_to_matmul,
    cast_constant_of_shape,
    collapse_slices,
+    fuse_relus_clips,
    no_op,
    pattern,
    redundant_scatter_nd,
@@ -32,6 +33,7 @@
    *broadcast_to_matmul.rules.rules,
    *cast_constant_of_shape.rules.rules,
    *collapse_slices.rules.rules,
+    *fuse_relus_clips.fuse_relus_clips_rules().rules,
    *basic_rules.basic_optimization_rules().rules,
    *redundant_scatter_nd.rules.rules,
)
onnxscript/rewriter/fuse_relus_clips.py: 190 changes (190 additions & 0 deletions, new file)
@@ -0,0 +1,190 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Does the following transformation:
- Relu(Relu(X)) -> Relu
- Relu(Clip(X)) -> Clip
- Clip(Relu(X)) -> Clip
- Clip(Clip(X)) -> Clip
"""

Review comment (Collaborator): "from __future__ import annotations would be helpful"
from __future__ import annotations

import abc

import numpy as np
import onnx_ir as ir

from onnxscript.rewriter import pattern as orp


class FuseSuccessiveRelu(orp.RewriteRuleClassBase):
    """Replaces ``Relu(Relu(X))`` with ``Relu(X)``."""

    def rewrite(self, op, x):
        return op.Relu(x)

    def pattern(self, op, x):
        return op.Relu(op.Relu(x))


class _FuseReluClipBase(orp.RewriteRuleClassBase, abc.ABC):
    def rewrite(self, op, x, **kwargs):
        first_clip_node = kwargs.get("out_first_clip").producer()
        second_clip_node = None

        if out_second_clip := kwargs.get("out_second_clip"):
            second_clip_node = out_second_clip.producer()

        min_clip, max_clip = self.compute_clip_min_max(first_clip_node, second_clip_node)
        clip_min_max = []

        if min_clip is not None:
            clip_min_max.append(
                op.initializer(min_clip, name=f"{first_clip_node.inputs[0].name}_min")
            )

        if max_clip is not None:
            # ONNX Clip expects min and max inputs in order.
            # If min is not provided, we insert None to maintain correct argument positions.
            if min_clip is None:
                clip_min_max.append(None)

            clip_min_max.append(
                op.initializer(max_clip, name=f"{first_clip_node.inputs[0].name}_max")
            )

        return op.Clip(x, *clip_min_max)
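
Note: ONNX declares Clip's min and max as positional optional inputs, which is why the rewrite inserts a None placeholder when only max is present. A hedged sketch of the same convention when building a node by hand with onnx.helper (assumed tensor names, not part of this diff):

from onnx import helper

# Clip with only a max bound: the min slot is held by "" so max stays in third position.
node = helper.make_node("Clip", inputs=["x", "", "x_max"], outputs=["y"])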

    @abc.abstractmethod
    def compute_clip_min_max(
        self, first_clip_node: ir.Node, second_clip_node: ir.Node | None = None
    ):
        pass

Codecov (codecov/patch) warning on line 60 of onnxscript/rewriter/fuse_relus_clips.py: added line #L60 was not covered by tests.

    def extract_min_max(self, node: ir.Node):
        # Infer dtype from the node's first input
        dtype = node.inputs[0].dtype.numpy()
        min_clip, max_clip = None, None

        if len(node.inputs) > 1:
            min_input = node.inputs[1]
            # If only a max is provided, min is implicitly None, so we check that
            if min_input is not None:
                min_clip = min_input.const_value.numpy()

        if len(node.inputs) > 2:
            max_clip = node.inputs[2].const_value.numpy()

        return min_clip, max_clip, dtype

    def check(self, context, **kwargs):
        """Condition to check if we need to replace the pattern.

        The pattern is applied only when the min and max inputs of the Clip nodes are
        not graph inputs and are constant values (i.e., provided by Constant nodes or initializers).

        Returns:
            MatchResult:
                Success if we need to replace the pattern, Failure otherwise.
        """
        del context  # Unused
        check_result = orp.MatchResult()

        # Check if Clip min/max are not graph inputs and are constant values
        clip_min_max = []

        first_clip_node = kwargs.get("out_first_clip").producer()
        clip_min_max.extend([inp for inp in first_clip_node.inputs[1:] if inp is not None])

        if out_second_clip := kwargs.get("out_second_clip"):
            second_clip_node = out_second_clip.producer()
            clip_min_max.extend(
                [inp for inp in second_clip_node.inputs[1:] if inp is not None]
            )

        for m in clip_min_max:
            if m.is_graph_input():
                return check_result.fail(f"{m.name} is a graph input.")

            if ir.convenience.get_const_tensor(m) is None:
                return check_result.fail(f"{m.name} is not a constant.")

        return check_result


class FuseSuccessiveClip(_FuseReluClipBase):
    """Replaces ``Clip(Clip(X))`` with ``Clip(X)``."""

    def pattern(self, op, x):
        return op.Clip(
            op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"]),
            _allow_other_inputs=True,
            _outputs=["out_second_clip"],
        )

    def compute_clip_min_max(self, first_clip_node: ir.Node, second_clip_node: ir.Node):
        min_clip1, max_clip1, dtype = self.extract_min_max(first_clip_node)
        min_clip2, max_clip2, _ = self.extract_min_max(second_clip_node)

        def combine(val1, val2, op):
            if val1 is not None and val2 is not None:
                return ir.tensor(np.array(op(val1, val2), dtype=dtype))
            elif val1 is not None:
                return ir.tensor(val1)
            elif val2 is not None:
                return ir.tensor(val2)
            return None

        min_clip = combine(min_clip1, min_clip2, np.maximum)
        max_clip = combine(max_clip1, max_clip2, np.minimum)

        return min_clip, max_clip
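
In other words, stacked Clips intersect their intervals: the fused min is the larger of the two mins, and the fused max is the smaller of the two maxes. A quick numpy check of the identity (illustrative only):

import numpy as np

x = np.array([-3.0, 1.0, 3.0, 5.0, 9.0])
stacked = np.clip(np.clip(x, 0.0, 6.0), 2.0, 4.0)  # Clip(Clip(x, 0, 6), 2, 4)
fused = np.clip(x, max(0.0, 2.0), min(6.0, 4.0))   # Clip(x, 2, 4)
assert np.array_equal(stacked, fused)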


class FuseSuccessiveClipRelu(_FuseReluClipBase):
    """Replaces ``Clip(Relu(X))`` with ``Clip(X)``."""

    def pattern(self, op, x):
        return op.Clip(op.Relu(x), _allow_other_inputs=True, _outputs=["out_first_clip"])

    def compute_clip_min_max(self, first_clip_node: ir.Node, _):
        min_clip, max_clip, dtype = self.extract_min_max(first_clip_node)

        if min_clip is None:
            # The minimum clipping value is implicitly 0 (Relu clamps at 0)
            min_clip = 0

        min_clip = ir.tensor(np.array(np.maximum(0.0, min_clip), dtype=dtype))

        if max_clip is not None:
            max_clip = ir.tensor(max_clip)
        return min_clip, max_clip
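
The same folding in numpy form: the Relu is absorbed by raising the Clip's lower bound to at least 0 (illustrative only):

import numpy as np

x = np.array([-3.0, -0.5, 2.0, 8.0])
original = np.clip(np.maximum(x, 0.0), -1.0, 6.0)  # Clip(Relu(x), -1, 6)
fused = np.clip(x, max(0.0, -1.0), 6.0)            # Clip(x, 0, 6)
assert np.array_equal(original, fused)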


class FuseSuccessiveReluClip(FuseSuccessiveClipRelu):
    """Replaces ``Relu(Clip(X))`` with ``Clip(X)``."""

    def pattern(self, op, x):
        return op.Relu(op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"]))


fuse_successive_relu_rule = FuseSuccessiveRelu().rule()
fuse_successive_clip_rule = FuseSuccessiveClip().rule()
fuse_successive_clip_relu_rule = FuseSuccessiveClipRelu().rule()
fuse_successive_relu_clip_rule = FuseSuccessiveReluClip().rule()


def fuse_relus_clips_rules() -> orp.RewriteRuleSet:
    """Returns a set of rewrite rules that fuse successive Relu/Clip nodes.

    Returns:
        RewriteRuleSet
    """

    # Order is important
    return orp.RewriteRuleSet(
        [
            fuse_successive_clip_relu_rule,
            fuse_successive_relu_clip_rule,
            fuse_successive_relu_rule,
            fuse_successive_clip_rule,
        ]
    )
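
A minimal usage sketch; it assumes onnx_ir's load/save helpers and RewriteRuleSet.apply_to_model (which returns the number of rewrites applied), with a hypothetical model path:

import onnx_ir as ir

from onnxscript.rewriter import fuse_relus_clips

model = ir.load("model.onnx")  # hypothetical path
count = fuse_relus_clips.fuse_relus_clips_rules().apply_to_model(model)
print(f"Applied {count} Relu/Clip fusion(s)")
ir.save(model, "model_fused.onnx")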