Add LayerNorm support for Vivado #1110


Merged: 77 commits into fastmachinelearning:main on Aug 5, 2025

Conversation

rianbrooksflynn
Contributor

@rianbrooksflynn commented Nov 4, 2024

Description

This PR adds support for Layer Normalization using either Keras or PyTorch with the Vivado backend in io_parallel mode.

This implementation uses a lookup table for inverse square root; the inputs to the lookup table follow a logarithmic distribution for better accuracy.
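The log-spaced indexing idea can be sketched in plain Python (an illustrative sketch only: the table size, covered range, and names below are assumptions, and the actual HLS implementation works in fixed point):

```python
import numpy as np

# Assumed parameters for illustration; the PR's HLS table uses fixed-point
# types and its own size/range choices.
TABLE_SIZE = 1024
LOG_MIN, LOG_MAX = -16.0, 16.0  # covered range of log2(x)

# Spacing table entries uniformly in log2(x) gives finer resolution near
# small x, where 1/sqrt(x) changes fastest.
LOG2_GRID = np.linspace(LOG_MIN, LOG_MAX, TABLE_SIZE)
INVSQRT_TABLE = 2.0 ** (-LOG2_GRID / 2.0)  # 1/sqrt(x) at each grid point

def invsqrt_lut(x):
    """Approximate 1/sqrt(x) by nearest lookup in the log-spaced table."""
    pos = (np.log2(x) - LOG_MIN) / (LOG_MAX - LOG_MIN) * (TABLE_SIZE - 1)
    idx = int(round(pos))
    idx = min(max(idx, 0), TABLE_SIZE - 1)  # clamp out-of-range inputs
    return float(INVSQRT_TABLE[idx])
```

A uniformly spaced table over the same input range would waste most of its entries on large variances, where the curve is nearly flat.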

Tests have been added for both Keras and PyTorch parsing.

Credit is due to @Ethan0Jiang and @LostEcho365 (Zhixing Jiang and Dennis Yin) for their Vivado implementation and Keras parsing support; my contributions were revising the inverse square root lookup table implementation, implementing PyTorch support, and adding unit tests. (Here's a link to their pre-print.) The original code authors have given permission for their code to be merged into hls4ml.

Linked issue: #1109

Type of change

  • New feature (non-breaking change which adds functionality)
  • A new research paper code implementation

Tests

Two unit tests added: test/pytest/test_layernorm.py and test/pytest/test_layernorm_pytorch.py

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

@The-Padi

The-Padi commented Jun 5, 2025

Any update on the merge ?

@vloncar
Contributor

vloncar commented Jun 5, 2025

First in line on my PR review TODO list, I expect early next week to have time for this.

@The-Padi

The-Padi commented Jun 5, 2025

> First in line on my PR review TODO list, I expect early next week to have time for this.

Thank you very much !

@JanFSchulte
Contributor

pre-commit.ci autofix

@JanFSchulte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Jul 1, 2025
@JanFSchulte
Contributor

@vloncar I hope this goes in the direction of what you had in mind for the performance validation. I ran synthesis in Vitis 2023.1, 2024.1, and 2025.1 for different input sizes to the LayerNorm and plotted FFs, LUTs, DSPs, BRAM, latency, and II as a function of that input size. 2024.1 and 2025.1 are basically identical, whereas 2023.1 uses a bit less resources but has worse latency.

This is with the default ap_fixed<16,6> and the default target part.

[Plots: FFs, LUTs, DSPs, BRAMs, latency, and II as a function of LayerNorm input size]

@vloncar
Contributor

vloncar commented Aug 5, 2025

Thanks, looks good. Do all reports say the timing is met? (No scheduling warnings, clock uncertainty met, etc.)

@JanFSchulte
Contributor

I did not observe any warnings about scheduling. The timing results all look very similar across versions.

static const unsigned dim = CONFIG_T::n_in / CONFIG_T::seq_len;
data_T in_val[dim];
res_T outval[dim];
// Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases

It does not. I think this can be removed from new code.
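For context, the per-row computation that the snippet above sets up can be sketched in NumPy (a plain-Python sketch ignoring fixed-point quantization; in the PR, the division by the square root is what the lookup table replaces):

```python
import numpy as np

def layernorm_row(x, gamma, beta, eps=1e-5):
    # Normalize one row of `dim` features: subtract the mean, scale by the
    # inverse standard deviation, then apply the learned affine transform.
    mean = x.mean()
    var = ((x - mean) ** 2).mean()
    inv_sqrt = 1.0 / np.sqrt(var + eps)  # the PR implements this via a LUT
    return (x - mean) * inv_sqrt * gamma + beta

def layernorm(x, gamma, beta, eps=1e-5):
    # x has shape (seq_len, dim); each row is normalized independently,
    # mirroring the loop over seq_len in the HLS implementation.
    return np.stack([layernorm_row(row, gamma, beta, eps) for row in x])
```

Each output row then has approximately zero mean and unit variance before the gamma/beta affine transform.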

hls_model = hls4ml.converters.convert_from_keras_model(
    custom_epsilon_model, backend=backend, hls_config=custom_config, io_type='io_parallel', output_dir=output_dir
)
hls_model.compile()

This test would be faster if we used hls_model.write(), or skipped the writing/linking step entirely; we don't use the compiled model here. The later accuracy test already checks that the produced code compiles.

# Predict
y_keras = model.predict(data).flatten()
y_hls = hls_model.predict(data).flatten()
np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)

Why is atol a global variable?


if not ((len(input_shapes[0])) == 3):
    raise Exception(
        'input size is not currently supported by hls4ml; '

Would be good to say Input shape <some shape> is not supported, only ...
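The suggested error message could look roughly like this (a hypothetical helper, not the PR's actual code; the function name and message wording are illustrative):

```python
def validate_layernorm_input(input_shapes):
    # Hypothetical helper sketching the reviewer's suggestion: report the
    # offending shape and what is supported, instead of a generic message.
    shape = input_shapes[0]
    if len(shape) != 3:
        raise Exception(
            f'Input shape {tuple(shape)} is not supported by hls4ml '
            f'LayerNormalization; only 3-dimensional '
            f'(batch, seq_len, features) inputs are supported.'
        )
    return shape
```

Including the concrete shape in the message saves a user a debugging round trip when a model fails to convert.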

@@ -44,6 +44,25 @@ def transform(self, model, node):
    node.get_output_variable().shape = input_shape
    dim_names = [f'N_INPUT_{i}_{node.index}' for i in range(1, len(input_shape) + 1)]
    node.get_output_variable().dim_names = dim_names
elif (
    isinstance(node, LayerNormalization)
    and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "off"

With the proposed change in #1352 we'll never get to the "off" check here.

@JanFSchulte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Aug 5, 2025
@vloncar merged commit fd41dc5 into fastmachinelearning:main Aug 5, 2025
5 of 8 checks passed