Open
Description
Hi,
I found this incompatibility while tracking down the source of our failing CI. Something seems to have changed* in onnxscript==0.5.4 that causes export of BatchNorm2d to fail. Given the existence of this file, I expect onnxscript==0.5.4 to be** compatible with PyTorch==2.6.0. The issue doesn't seem to occur with PyTorch>2.6.
The minimal reproducer is as follows:
import torch


class MinimalBn(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(x)


x = torch.rand(1, 8, 2, 2)
model = MinimalBn()
output = model(x)  # Make sure the forward path is traceable
model.eval()
with torch.no_grad():
    torch.onnx.export(
        model,
        (x),
        "minimal.onnx",
        input_names=['x'],
        dynamo=True,
    )

which produces the following error:
<class 'torch.onnx._internal.exporter._errors.ConversionError'>: Error when translating node %_native_batch_norm_legit_no_training : [num_users=1] = call_function[target=torch.ops.aten._native_batch_norm_legit_no_training.default](args = (%x, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, 0.1, 1e-05), kwargs = {}). See the stack trace for more information.
The full log is here. FYI, I'm using:
- PyTorch==2.6.0
- onnxscript==0.5.4
- python=3.11.14
*I see no issue with onnxscript==0.5.3.
**Let me know if I'm mistaken, and I'll recommend that PyTorch make a patch release which enforces onnxscript<=0.5.3 with PyTorch==2.6.x.
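In the meantime, a CI job could guard against the combination reported above. The following is a minimal sketch, not an official check: the helper names `parse_release` and `is_reported_bad_combo` are hypothetical, and treating onnxscript releases newer than 0.5.4 as affected is an assumption, since only 0.5.3 (works) and 0.5.4 (fails) were tested here.

```python
import re


def parse_release(version: str) -> tuple[int, ...]:
    """Return the leading numeric release, e.g. '2.6.0+cu124' -> (2, 6, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def is_reported_bad_combo(torch_version: str, onnxscript_version: str) -> bool:
    """True for PyTorch 2.6.x paired with onnxscript newer than 0.5.3."""
    on_torch_2_6 = parse_release(torch_version)[:2] == (2, 6)
    # Assumption: anything after 0.5.3 is flagged; only 0.5.4 was observed to fail.
    newer_than_0_5_3 = parse_release(onnxscript_version) > (0, 5, 3)
    return on_torch_2_6 and newer_than_0_5_3


if __name__ == "__main__":
    from importlib.metadata import PackageNotFoundError, version

    try:
        if is_reported_bad_combo(version("torch"), version("onnxscript")):
            print("PyTorch 2.6.x with onnxscript>0.5.3: BatchNorm2d export may fail")
    except PackageNotFoundError as exc:
        print(f"package not installed: {exc}")
```

If the check fires, pinning onnxscript<=0.5.3 in the environment matches the workaround suggested above.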