Open
Description
Feature Support Request
I would like to fuse ConvTranspose2d+BatchNorm2d+ReLU
AssertionError: did not find fuser method for: (<class 'torch.nn.modules.conv.ConvTranspose2d'>, <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, <class 'torch.nn.modules.activation.ReLU'>)
Reproduce
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
python3 fuse.py
fuse.py
import torch
import torch.nn as nn
class DummyModel(nn.Module):
def __init__(self):
super(DummyModel, self).__init__()
self.conv = nn.ConvTranspose2d(3, 32, 3, 1)
self.bn = nn.BatchNorm2d(32)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
modules_to_fuse = ["conv", "bn", "relu"]
model = DummyModel()
model.eval()
model = torch.ao.quantization.fuse_modules(model, modules_to_fuse)