
Support for parsing ONNX Pad node #1352


Open
vloncar wants to merge 3 commits into main

Conversation

@vloncar (Contributor) commented Aug 4, 2025

Description

Follow-up to #1322 to add support for the Pad op in (Q)ONNX.

I cleaned up the test so that it checks both PyTorch and ONNX. Additionally, the channels-last converter seems to ignore the setting "off", so I modified that too. This may need an additional check.

Type of change

  • Bug fix (non-breaking change that fixes an issue) - The change in handling ChannelsLastConversion = off
  • New feature (non-breaking change which adds functionality) - Support for ONNX Pad op

Tests

Tests are included in test_zeropadding_pytorch_onnx.py
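
For context, a minimal sketch (illustrative, not taken from this PR) of how such a Pad node arises from PyTorch, assuming a simple 1D constant zero-padding model like the ones the new test exercises:

import torch
import torch.nn as nn

# Hypothetical model: pad a 1D signal with one zero on the left and two on the right.
model = nn.Sequential(nn.ConstantPad1d(padding=(1, 2), value=0.0))
model.eval()

# Tracing this into ONNX yields a graph containing a Pad node with mode 'constant',
# which the parser added in this PR has to handle.
torch.onnx.export(model, torch.randn(1, 2, 4), 'pad1d.onnx')

Note that ONNX Pad lists padding as all begin values followed by all end values over every axis, while PyTorch's pad tuples are ordered last dimension first, so a converter has to translate between the two layouts.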

@vloncar requested review from jmitrevs and JanFSchulte on August 4, 2025 18:54
@@ -13,8 +13,9 @@ class ChannelsLastConverter(OptimizerPass):

     def match(self, node):
         # If this parameter has not been set, this model does not need to be converted
         if 'ChannelsLastConversion' not in node.model.config.config['HLSConfig']['Model']:
             return False  # No littering of unused property
+        do_convert = node.model.config.config['HLSConfig']['Model'].get('ChannelsLastConversion', 'off')
@vloncar (Author) commented on this change:

@JanFSchulte Can you check this? I found that if I set ChannelsLastConversion to off as instructed by the docs, nothing happens: we still enter this optimizer and the node's data_format ends up changed to channels_last.

@JanFSchulte (Contributor) replied:

Yeah, that seems like a bug, thanks for catching it. I'm not quite sure how that happened, since it looks like I implemented the switches in the config without ever propagating the 'off' setting to the optimizer. But this change looks good to me.
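
For reference, a sketch of how match() presumably reads with the new line in place; the final 'off' check and the trailing logic are assumptions inferred from this discussion, not copied from the file:

def match(self, node):
    # If this parameter has not been set, this model does not need to be converted
    if 'ChannelsLastConversion' not in node.model.config.config['HLSConfig']['Model']:
        return False  # No littering of unused property
    do_convert = node.model.config.config['HLSConfig']['Model'].get('ChannelsLastConversion', 'off')
    if do_convert == 'off':
        # Assumed continuation: honor the user's 'off' setting instead of converting anyway
        return False
    # ... remaining matching conditions unchanged ...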

hls_model_pytorch.compile()

onnx_path = str(test_root_path / 'hls4mlprj_constpad_1d/pad1d.onnx')
torch.onnx.export(model, torch.randn(1, 2, 4), onnx_path, opset_version=10)
A reviewer (Contributor) commented on the export call above:

Isn't the recommended usage now with dynamo=True? I have previously followed the recipe from https://docs.pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html successfully, though I can't guarantee that it would work here.

@vloncar (Author) replied:

I saw that the docs now push dynamo=True a lot, but I wasn't sure what our stance on it is. I can change it if that's what you prefer. I realize now that I have to change this line anyway: opset_version=10 is a workaround for my local env and should not be in the code.

The reviewer replied:

Looking at the docs, they do push it pretty heavily indeed, so I'd be in favor of going with the dynamo=True option just to be sure we support the recommended usage.
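
For illustration, a minimal sketch of the dynamo-based export being discussed, assuming a recent PyTorch release in which torch.onnx.export accepts dynamo=True and returns an ONNXProgram; the model and output path are placeholders:

import torch
import torch.nn as nn

model = nn.Sequential(nn.ConstantPad1d(padding=(1, 2), value=0.0))
model.eval()
example_input = torch.randn(1, 2, 4)

# dynamo=True selects the newer torch.export-based ONNX exporter
# instead of the legacy TorchScript tracer.
onnx_program = torch.onnx.export(model, (example_input,), dynamo=True)
onnx_program.save('pad1d.onnx')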

@vloncar mentioned this pull request on Aug 5, 2025
@vloncar added the 'please test' label (Trigger testing by creating local PR branch) on Aug 6, 2025