Skip to content

Support for parsing ONNX Pad node #1352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions hls4ml/converters/onnx/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,68 @@ def parse_resize_layer(node, input_names, input_shapes, graph):
)

return layer


@onnx_handler('Pad')
def parse_pad_layer(node, input_names, input_shapes, graph):
    """Parse an ONNX ``Pad`` node into an hls4ml ``ZeroPadding1D``/``ZeroPadding2D`` layer.

    Only constant-mode padding is supported (the value is assumed to be zero).
    3-dim inputs (batch, channels, width) map to ZeroPadding1D and 4-dim inputs
    (batch, channels, height, width) map to ZeroPadding2D.

    Args:
        node: The ONNX Pad NodeProto being parsed.
        input_names (list): Names of the node's inputs.
        input_shapes (list): Shapes of the node's inputs; input_shapes[0] is the data tensor.
        graph: The ONNX GraphProto containing the node.

    Returns:
        dict: The hls4ml layer configuration.

    Raises:
        RuntimeError: For non-constant padding modes, missing 'pads' attribute,
            or unsupported input ranks.
    """
    layer = {}
    layer['name'] = node.name
    layer['class_name'] = 'ZeroPadding'
    layer['inputs'] = input_names
    layer['outputs'] = list(node.output)
    # If any node in the graph uses the qonnx channels_last domain, the whole model
    # is assumed to be in channels_last format. Use 'n' as the loop variable so the
    # 'node' parameter is not shadowed inside the generator expression.
    layer['data_format'] = (
        'channels_last' if any(n.domain == 'qonnx.custom_op.channels_last' for n in graph.node) else 'channels_first'
    )

    mode = get_onnx_attribute(node, 'mode')
    if mode is not None and mode != 'constant':
        raise RuntimeError(f'Unsupported padding mode: {mode} in node {node.name}')

    pads = get_onnx_attribute(node, 'pads')
    # NOTE(review): since ONNX opset 11, 'pads' is an input tensor rather than an
    # attribute, in which case get_onnx_attribute returns None. Raise a descriptive
    # error instead of an opaque TypeError on the subscript below.
    if pads is None:
        raise RuntimeError(f'Could not parse "pads" attribute of Pad node {node.name}')

    dim = 0
    if len(input_shapes[0]) == 3:
        dim = 1  # 2D input (batch, channels, width), will use ZeroPadding1D
        # ONNX pads layout: [x1_begin, x2_begin, x3_begin, x1_end, x2_end, x3_end]
        if layer['data_format'] == 'channels_first':
            _, channels, width = input_shapes[0]
            pad_left, pad_right = pads[2], pads[5]
        else:
            _, width, channels = input_shapes[0]
            pad_left, pad_right = pads[1], pads[4]
        out_width = width + pad_left + pad_right

        layer['n_chan'] = channels
        layer['in_width'] = width
        layer['out_width'] = out_width

        layer['pad_left'] = pad_left
        layer['pad_right'] = pad_right
    elif len(input_shapes[0]) == 4:
        dim = 2  # 3D input (batch, channels, height, width), will use ZeroPadding2D
        if layer['data_format'] == 'channels_first':
            _, channels, height, width = input_shapes[0]
            pad_top, pad_bottom = pads[2], pads[6]
            pad_left, pad_right = pads[3], pads[7]
        else:
            _, height, width, channels = input_shapes[0]
            pad_top, pad_bottom = pads[1], pads[5]
            pad_left, pad_right = pads[2], pads[6]
        out_height = height + pad_top + pad_bottom
        out_width = width + pad_left + pad_right

        layer['n_chan'] = channels
        layer['in_height'] = height
        layer['in_width'] = width
        layer['out_height'] = out_height
        layer['out_width'] = out_width

        layer['pad_top'] = pad_top
        layer['pad_bottom'] = pad_bottom
        layer['pad_left'] = pad_left
        layer['pad_right'] = pad_right
    else:
        raise RuntimeError(f'Unsupported input shape: {input_shapes[0]} for Pad node {node.name}')

    layer['class_name'] += str(dim) + 'D'

    return layer
4 changes: 4 additions & 0 deletions hls4ml/converters/pytorch/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def parse_constantpad2d_layer(operation, layer_name, input_names, input_shapes,
layer['out_height'] = out_height
layer['out_width'] = out_width

layer['data_format'] = 'channels_first' # Default data format in PyTorch

return layer, output_shape


Expand Down Expand Up @@ -243,4 +245,6 @@ def parse_constantpad1d_layer(operation, layer_name, input_names, input_shapes,
layer['in_width'] = width
layer['out_width'] = out_width

layer['data_format'] = 'channels_first' # Default data format in PyTorch

return layer, output_shape
44 changes: 0 additions & 44 deletions test/pytest/test_pytorch_constpadmapping.py

This file was deleted.

86 changes: 86 additions & 0 deletions test/pytest/test_zeropadding_pytorch_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from pathlib import Path

import numpy as np
import qonnx.util.cleanup
import torch
import torch.nn as nn
from qonnx.core.modelwrapper import ModelWrapper

from hls4ml.converters import convert_from_onnx_model, convert_from_pytorch_model
from hls4ml.utils.config import config_from_onnx_model, config_from_pytorch_model

test_root_path = Path(__file__).parent


def test_constantpad_1d():
    """Check that hls4ml's PyTorch and ONNX frontends agree on ConstantPad1d."""

    class Pad1DModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Asymmetric zero padding: 2 elements on the left, 3 on the right
            self.pad = nn.ConstantPad1d((2, 3), 0)

        def forward(self, x):
            return self.pad(x)

    torch_model = Pad1DModel()
    torch_model.eval()

    # Build the hls4ml model directly from the PyTorch module
    pt_config = config_from_pytorch_model(torch_model, (2, 4), channels_last_conversion='off')
    hls_from_pytorch = convert_from_pytorch_model(
        torch_model, output_dir=str(test_root_path / 'hls4mlprj_constpad_1d/pytorch'), hls_config=pt_config
    )
    hls_from_pytorch.compile()

    # Export the same module to ONNX, clean it up with qonnx, and convert that too
    onnx_path = str(test_root_path / 'hls4mlprj_constpad_1d/pad1d.onnx')
    torch.onnx.export(torch_model, torch.randn(1, 2, 4), onnx_path, dynamo=True)
    qonnx.util.cleanup.cleanup(onnx_path, out_file=onnx_path)
    onnx_model = ModelWrapper(onnx_path)

    onnx_config = config_from_onnx_model(onnx_model)
    hls_from_onnx = convert_from_onnx_model(
        onnx_model, output_dir=str(test_root_path / 'hls4mlprj_constpad_1d/onnx'), hls_config=onnx_config
    )
    hls_from_onnx.compile()

    # Both conversion paths must produce identical predictions
    batch = np.random.randn(10, 2, 4)
    np.testing.assert_allclose(hls_from_pytorch.predict(batch), hls_from_onnx.predict(batch), rtol=0, atol=1e-5)


def test_constantpad_2d():
    """Check that hls4ml's PyTorch and ONNX frontends agree on ConstantPad2d."""

    class Pad2DModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Zero padding in (left, right, top, bottom) order
            self.pad = nn.ConstantPad2d((1, 2, 3, 4), 0)

        def forward(self, x):
            return self.pad(x)

    torch_model = Pad2DModel()
    torch_model.eval()

    # Build the hls4ml model directly from the PyTorch module
    pt_config = config_from_pytorch_model(torch_model, (2, 3, 4), channels_last_conversion='off')
    hls_from_pytorch = convert_from_pytorch_model(
        torch_model, output_dir=str(test_root_path / 'hls4mlprj_constpad_2d/pytorch'), hls_config=pt_config
    )
    hls_from_pytorch.compile()

    # Export the same module to ONNX, clean it up with qonnx, and convert that too
    onnx_path = str(test_root_path / 'hls4mlprj_constpad_2d/pad2d.onnx')
    torch.onnx.export(torch_model, torch.randn(1, 2, 3, 4), onnx_path, dynamo=True)
    qonnx.util.cleanup.cleanup(onnx_path, out_file=onnx_path)
    onnx_model = ModelWrapper(onnx_path)

    onnx_config = config_from_onnx_model(onnx_model)
    hls_from_onnx = convert_from_onnx_model(
        onnx_model, output_dir=str(test_root_path / 'hls4mlprj_constpad_2d/onnx'), hls_config=onnx_config
    )
    hls_from_onnx.compile()

    # Both conversion paths must produce identical predictions
    batch = np.random.randn(10, 2, 3, 4)
    np.testing.assert_allclose(hls_from_pytorch.predict(batch), hls_from_onnx.predict(batch), rtol=0, atol=1e-5)
Loading