Skip to content

Create torch_compile_conv_bn_fuser tutorial adapted from fx_conv_bn_fuser #3458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 14, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .jenkins/validate_tutorials_built.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"beginner_source/examples_autograd/polynomial_autograd",
"beginner_source/examples_autograd/polynomial_custom_function",
"intermediate_source/mnist_train_nas", # used by ax_multiobjective_nas_tutorial.py
"intermediate_source/fx_conv_bn_fuser",
"intermediate_source/torch_compile_conv_bn_fuser",
"intermediate_source/_torch_export_nightly_tutorial", # does not work on release
"advanced_source/usb_semisup_learn", # fails with CUDA OOM error, should try on a different worker
"prototype_source/fx_graph_mode_ptq_dynamic",
Expand Down
16 changes: 8 additions & 8 deletions index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,6 @@ Welcome to PyTorch Tutorials

.. Code Transformations with FX

.. customcarditem::
:header: Building a Convolution/Batch Norm fuser in FX
:card_description: Build a simple FX pass that fuses batch norm into convolution to improve performance during inference.
:image: _static/img/thumbnails/cropped/Deploying-PyTorch-in-Python-via-a-REST-API-with-Flask.png
:link: intermediate/fx_conv_bn_fuser.html
:tags: FX

.. customcarditem::
:header: Building a Simple Performance Profiler with FX
:card_description: Build a simple FX interpreter to record the runtime of op, module, and function calls and report statistics
Expand Down Expand Up @@ -583,6 +576,13 @@ Welcome to PyTorch Tutorials
:link: intermediate/torch_compile_tutorial.html
:tags: Model-Optimization

.. customcarditem::
:header: Building a Convolution/Batch Norm fuser in torch.compile
:card_description: Build a simple pattern matcher pass that fuses batch norm into convolution to improve performance during inference.
:image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
:link: intermediate/torch_compile_conv_bn_fuser.html
:tags: Model-Optimization

.. customcarditem::
:header: Inductor CPU Backend Debugging and Profiling
:card_description: Learn the usage, debugging and performance profiling for ``torch.compile`` with Inductor CPU backend.
Expand Down Expand Up @@ -950,7 +950,6 @@ Additional Resources
:hidden:
:caption: Code Transforms with FX

intermediate/fx_conv_bn_fuser
intermediate/fx_profiling_tutorial

.. toctree::
Expand Down Expand Up @@ -1001,6 +1000,7 @@ Additional Resources
intermediate/nvfuser_intro_tutorial
intermediate/ax_multiobjective_nas_tutorial
intermediate/torch_compile_tutorial
intermediate/torch_compile_conv_bn_fuser
intermediate/compiled_autograd_tutorial
intermediate/inductor_debug_cpu
intermediate/scaled_dot_product_attention_tutorial
Expand Down
274 changes: 152 additions & 122 deletions intermediate_source/fx_conv_bn_fuser.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
# -*- coding: utf-8 -*-
"""
(beta) Building a Convolution/Batch Norm fuser in FX
*******************************************************
**Author**: `Horace He <https://github.com/chillee>`_
Building a Convolution/Batch Norm fuser with torch.compile
===========================================================
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do yo want to rename this file to torch_compile_conv_bn_fuser?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg! I just renamed it


In this tutorial, we are going to use FX, a toolkit for composable function
transformations of PyTorch, to do the following:
**Author:** `Horace He <https://github.com/chillee>`_, `Will Feng <https://github.com/yf225>`_

1) Find patterns of conv/batch norm in the data dependencies.
2) For the patterns found in 1), fold the batch norm statistics into the convolution weights.
.. grid:: 2

Note that this optimization only works for models in inference mode (i.e. `mode.eval()`)
.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
:class-card: card-prerequisites

We will be building the fuser that exists here:
https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py
* How to register custom fusion patterns with torch.compile's pattern matcher

.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
:class-card: card-prerequisites

* PyTorch v2.7.0

.. note::
This optimization only works for models in inference mode (i.e. ``model.eval()``).
However, torch.compile's pattern matching system works for both training and inference.

"""

Expand All @@ -24,10 +30,11 @@

from typing import Type, Dict, Any, Tuple, Iterable
import copy
import torch.fx as fx
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################
# For this tutorial, we are going to create a model consisting of convolutions
# and batch norms. Note that this model has some tricky components - some of
Expand Down Expand Up @@ -61,29 +68,26 @@ def forward(self, x):
x = self.wrapped(x)
return x

model = M()

model = M().to(device)
model.eval()

######################################################################
# Fusing Convolution with Batch Norm
# -----------------------------------------
# One of the primary challenges with trying to automatically fuse convolution
# and batch norm in PyTorch is that PyTorch does not provide an easy way of
# accessing the computational graph. FX resolves this problem by symbolically
# tracing the actual operations called, so that we can track the computations
# through the `forward` call, nested within Sequential modules, or wrapped in
# an user-defined module.

traced_model = torch.fx.symbolic_trace(model)
print(traced_model.graph)
# accessing the computational graph. torch.compile resolves this problem by
# capturing the computational graph during compilation, allowing us to apply
# pattern-based optimizations across the entire model, including operations
# nested within Sequential modules or wrapped in custom modules.
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import register_replacement

######################################################################
# This gives us a graph representation of our model. Note that both the modules
# hidden within the sequential as well as the wrapped Module have been inlined
# into the graph. This is the default level of abstraction, but it can be
# configured by the pass writer. More information can be found at the FX
# overview https://pytorch.org/docs/master/fx.html#module-torch.fx
# torch.compile will capture a graph representation of our model. During
# compilation, modules hidden within Sequential containers and wrapped
# modules are all inlined into the graph, making them available for
# pattern matching and optimization.


####################################
Expand Down Expand Up @@ -128,78 +132,74 @@ def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):


####################################
# FX Fusion Pass
# ----------------------------------
# Now that we have our computational graph as well as a method for fusing
# convolution and batch norm, all that remains is to iterate over the FX graph
# and apply the desired fusions.


def _parent_name(target : str) -> Tuple[str, str]:
"""
Splits a ``qualname`` into parent path and last atom.
For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
"""
*parent, name = target.rsplit('.', 1)
return parent[0] if parent else '', name

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
assert(isinstance(node.target, str))
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name, new_module)


def fuse(model: torch.nn.Module) -> torch.nn.Module:
model = copy.deepcopy(model)
# The first step of most FX passes is to symbolically trace our model to
# obtain a `GraphModule`. This is a representation of our original model
# that is functionally identical to our original model, except that we now
# also have a graph representation of our forward pass.
fx_model: fx.GraphModule = fx.symbolic_trace(model)
modules = dict(fx_model.named_modules())

# The primary representation for working with FX are the `Graph` and the
# `Node`. Each `GraphModule` has a `Graph` associated with it - this
# `Graph` is also what generates `GraphModule.code`.
# The `Graph` itself is represented as a list of `Node` objects. Thus, to
# iterate through all of the operations in our graph, we iterate over each
# `Node` in our `Graph`.
for node in fx_model.graph.nodes:
# The FX IR contains several types of nodes, which generally represent
# call sites to modules, functions, or methods. The type of node is
# determined by `Node.op`.
if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it.
continue
# For call sites, `Node.target` represents the module/function/method
# that's being called. Here, we check `Node.target` to see if it's a
# batch norm module, and then check `Node.args[0].target` to see if the
# input `Node` is a convolution.
if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:
if len(node.args[0].users) > 1: # Output of conv is used by other nodes
continue
conv = modules[node.args[0].target]
bn = modules[node.target]
fused_conv = fuse_conv_bn_eval(conv, bn)
replace_node_module(node.args[0], modules, fused_conv)
# As we've folded the batch nor into the conv, we need to replace all uses
# of the batch norm with the conv.
node.replace_all_uses_with(node.args[0])
# Now that all uses of the batch norm have been replaced, we can
# safely remove the batch norm.
fx_model.graph.erase_node(node)
fx_model.graph.lint()
# After we've modified our graph, we need to recompile our graph in order
# to keep the generated code in sync.
fx_model.recompile()
return fx_model
# Pattern Matching with torch.compile
# ------------------------------------
# Now that we have our fusion logic, we need to register a pattern that
# torch.compile's pattern matcher will recognize and replace during
# compilation.

# Define the pattern we want to match: conv2d followed by batch_norm
def conv_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
conv_out = torch.nn.functional.conv2d(x, conv_weight, conv_bias)
bn_out = torch.nn.functional.batch_norm(
conv_out, bn_mean, bn_var, bn_weight, bn_bias,
training=False, eps=1e-5
)
return bn_out

def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
fused_weight, fused_bias = fuse_conv_bn_weights(
conv_weight, conv_bias, bn_mean, bn_var, 1e-5, bn_weight, bn_bias
)
return torch.nn.functional.conv2d(x, fused_weight, fused_bias)

# Example inputs are needed to trace the pattern functions.
# The inputs should match the function signatures of conv_bn_pattern and conv_bn_replacement.
# These are used to trace the pattern functions to create the match template.
# IMPORTANT: The pattern matcher is shape-agnostic! The specific shapes you use here
# don't limit what shapes will be matched - any valid conv2d->batch_norm sequence
# will be matched regardless of channels, kernel size, or spatial dimensions.
# - x: input tensor (batch_size, channels, height, width)
# - conv_weight: (out_channels, in_channels, kernel_h, kernel_w)
# - conv_bias: (out_channels,)
# - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,) matching out_channels
example_inputs = [
torch.randn(1, 1, 4, 4).to(device), # x: input tensor
torch.randn(1, 1, 1, 1).to(device), # conv_weight: 1 output channel, 1 input channel, 1x1 kernel
torch.randn(1).to(device), # conv_bias: 1 output channel
torch.randn(1).to(device), # bn_mean: batch norm running mean
torch.randn(1).to(device), # bn_var: batch norm running variance
torch.randn(1).to(device), # bn_weight: batch norm weight (gamma)
torch.randn(1).to(device), # bn_bias: batch norm bias (beta)
]

from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._inductor import config

# Create a pattern matcher pass and register our pattern
patterns = PatternMatcherPass()

register_replacement(
conv_bn_pattern,
conv_bn_replacement,
example_inputs,
pm.fwd_only,
patterns,
)

# Create a custom pass function that applies our patterns
def conv_bn_fusion_pass(graph):
return patterns.apply(graph)

# Set our custom pass in the config
config.post_grad_custom_post_pass = conv_bn_fusion_pass


######################################################################
# .. note::
# We make some simplifications here for demonstration purposes, such as only
# matching 2D convolutions. View
# https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py
# for a more usable pass.
# matching 2D convolutions. The pattern matcher in torch.compile
# can handle more complex patterns.

######################################################################
# Testing out our Fusion Pass
Expand All @@ -208,11 +208,43 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module:
# results are identical. In addition, we can print out the code for our fused
# model and verify that there are no more batch norms.

from torch._dynamo.utils import counters

fused_model = fuse(model)
print(fused_model.code)
inp = torch.randn(5, 1, 1, 1)
torch.testing.assert_allclose(fused_model(inp), model(inp))
# Clear the counters before compilation
counters.clear()

# Ensure pattern matcher is enabled
config.pattern_matcher = True

fused_model = torch.compile(model, backend="inductor")
inp = torch.randn(5, 1, 1, 1).to(device)

# Run the model to trigger compilation and pattern matching
with torch.no_grad():
output = fused_model(inp)
expected = model(inp)
torch.testing.assert_close(output, expected)

# Check how many patterns were matched
assert counters['inductor']['pattern_matcher_count'] == 3, "Expected 3 conv-bn patterns to be matched"

# Create a model with different shapes than our example_inputs
test_model_diff_shape = nn.Sequential(
nn.Conv2d(3, 16, 5),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.Conv2d(16, 32, 7),
nn.BatchNorm2d(32),
).to(device).eval()

counters.clear()
compiled_diff_shape = torch.compile(test_model_diff_shape, backend="inductor")
test_input_diff_shape = torch.randn(1, 3, 28, 28).to(device)
with torch.no_grad():
compiled_diff_shape(test_input_diff_shape)

# Check how many patterns were matched
assert counters['inductor']['pattern_matcher_count'] == 2, "Expected 2 conv-bn patterns to be matched"


######################################################################
Expand All @@ -223,40 +255,38 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module:
import torchvision.models as models
import time

rn18 = models.resnet18()
rn18 = models.resnet18().to(device)
rn18.eval()

inp = torch.randn(10, 3, 224, 224)
inp = torch.randn(10, 3, 224, 224).to(device)
output = rn18(inp)

def benchmark(model, iters=20):
for _ in range(10):
model(inp)
begin = time.time()
for _ in range(iters):
model(inp)
return str(time.time()-begin)

fused_rn18 = fuse(rn18)
print("Unfused time: ", benchmark(rn18))
print("Fused time: ", benchmark(fused_rn18))
######################################################################
# As we previously saw, the output of our FX transformation is
# ("torchscriptable") PyTorch code, we can easily ``jit.script`` the output to try
# and increase our performance even more. In this way, our FX model
# transformation composes with TorchScript with no issues.
jit_rn18 = torch.jit.script(fused_rn18)
print("jit time: ", benchmark(jit_rn18))
with torch.no_grad():
for _ in range(10):
model(inp)
begin = time.time()
for _ in range(iters):
model(inp)
return str(time.time()-begin)

# Benchmark original model
print("Original model time: ", benchmark(rn18))

# Compile with our custom pattern
compiled_with_pattern_matching = torch.compile(rn18, backend="inductor")

# Benchmark compiled model
print("\ntorch.compile (with conv-bn pattern matching and other fusions): ", benchmark(compiled_with_pattern_matching))


############
# Conclusion
# ----------
# As we can see, using FX we can easily write static graph transformations on
# PyTorch code.
# As we can see, torch.compile provides a powerful way to implement
# graph transformations and optimizations through pattern matching.
# By registering custom patterns, we can extend torch.compile's
# optimization capabilities to handle domain-specific transformations.
#
# Since FX is still in beta, we would be happy to hear any
# feedback you have about using it. Please feel free to use the
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
# you might have.
# The conv-bn fusion demonstrated here is just one example of what's
# possible with torch.compile's pattern matching system.
Loading
Loading