
Automatic type inference for param_t in Parametrised Activations #1139


Open · nghielme wants to merge 16 commits into base: main

Conversation

nghielme
Contributor

@nghielme nghielme commented Dec 4, 2024

This small PR implements the inference of the W (width) and I (integer) parameters for a given floating-point constant. It is exploited in parametrised activations.
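
For illustration, a small hypothetical example of the idea (the inference rules it relies on are discussed later in this thread; names and types here are assumptions, not the PR's actual code):

import math

alpha = 0.25  # activ_param of a hypothetical LeakyReLU layer
assert math.log2(alpha).is_integer()  # a power of two: one value bit suffices
# so a type along the lines of ac_fixed<1,-1,false> could replace a default
# ac_fixed<16,6,true> param_t for this constant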

Type of change

  • New feature (non-breaking change which adds functionality)

Tests

I ran some tests related to parametrised activations that are already present in the pytests of hls4ml.

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

@nghielme nghielme requested a review from jmitrevs December 5, 2024 06:53
@nghielme nghielme added the please test Trigger testing by creating local PR branch label Dec 5, 2024
@nghielme
Contributor Author

nghielme commented Dec 9, 2024

I see that some tests related to oneAPI fail; it's hard for me to understand why they fail. How should I proceed?

@JanFSchulte JanFSchulte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Dec 16, 2024
@jmitrevs
Contributor

If you have a Linux setup it should be pretty straightforward to install oneAPI, and then you can run the pytest locally. But we can wait and look at the other issues first; maybe it will clear itself up.

@JanFSchulte
Contributor

I wanted to try to install oneAPI myself, so I played with this PR a bit. The issue seems to be that the precision for the parameter of the leaky ReLU is reduced significantly, from typedef ac_fixed<16,6,true> quantizer_param_t; to a one-bit typedef ac_fixed<1,0,false> quantizer_param_t;. Vivado and the other backends seem to be able to handle it, but I'm not sure it makes sense in this case, because we have negative slopes here and need the type to be signed. For oneAPI, a signed variable is enforced to have at least two bits:

signed _BitInt must have a bit size of at least 2

So we need to make sure to take this into account when inferring the precision for the parameters.
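
A minimal sketch of such a guard, assuming a (width, integer, signed) description in the ac_fixed convention; the helper name is hypothetical and this is not the exact change made in the PR:

def widen_signed_for_oneapi(width, integer, signed):
    # the oneAPI backend rejects 1-bit signed types ("signed _BitInt must have
    # a bit size of at least 2"), so pad the inferred type on the integer side;
    # the fractional step (width - integer) is unchanged
    if signed and width < 2:
        pad = 2 - width
        width += pad
        integer += pad
    return width, integer, signed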

@bo3z bo3z added this to the v1.1.0 milestone Mar 7, 2025
@bo3z bo3z modified the milestones: v1.1.0, v1.2.0 Apr 8, 2025
@JanFSchulte
Contributor

Hey @nghielme any news on this one?

@nghielme
Contributor Author

nghielme commented Jun 6, 2025

I'll take a look soon

@jmitrevs jmitrevs marked this pull request as draft July 24, 2025 16:47
@jmitrevs jmitrevs marked this pull request as ready for review July 26, 2025 01:17
@jmitrevs
Contributor

Please check the logic of the precision setting.

@jmitrevs
Contributor

I added a unit test to cover the various options, so I am more confident. It did discover an error in the max setting for unsigned FixedPrecisionType, which I fixed, and am including here, though it's logically independent.
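
For reference, a sketch of the bound in question, assuming the usual ap_fixed/ac_fixed convention in which I counts the integer bits including the sign bit; fixed_point_bounds is an illustrative helper, not the code touched by the PR:

def fixed_point_bounds(width, integer, signed):
    # weight of the least-significant bit
    lsb = 2.0 ** (integer - width)
    if signed:
        return -(2.0 ** (integer - 1)), 2.0 ** (integer - 1) - lsb
    # unsigned types cover the full range 0 .. 2**integer - lsb
    return 0.0, 2.0 ** integer - lsb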

@JanFSchulte JanFSchulte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Jul 29, 2025
    else:
        # find a constant to represent the values
        param = node.get_attr('activ_param')
        precision = _get_precision_from_constant(param)
Contributor

If I understand this correctly, we are basically hard-coding the bit width of the parameter to be 8 (9 if signed) and assigning the fractional and integer bits based on the value. Is that correct? I ask because finding a way to infer the needed total precision is something that has stumped me forever when working on the brevitas stuff, but it seems that here as well the only solution is to hard-code some arbitrary value.

Contributor

@jmitrevs jmitrevs Jul 29, 2025

That's mainly for the cases where the ones above don't apply. For 0 values we just use 1 bit. For power-of-2 values we use a width of 1 or 2, depending on whether the value is negative or not. Then comes the attempt to use Fxp from fxpmath, which is logically like a struct. It works well for values like 1.25 that can be represented exactly; in those cases, the optimizer uses the width from Fxp. But if that produces a width larger than 8 (not including the sign bit), then the size is capped at 8, with the appropriate range being set by the integer size. Note that Fxp would otherwise attempt to use 56 bits to store 1.1; these we cut off at 8 bits.
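
A rough sketch of that strategy, assuming the ap_fixed/ac_fixed convention where the integer width includes the sign bit; infer_param_precision is a simplified stand-in for _get_precision_from_constant, not the PR's exact implementation:

import math

from fxpmath import Fxp


def infer_param_precision(value, max_width=8):
    """Return (width, integer, signed) for a floating-point constant."""
    signed = value < 0
    if value == 0:
        return 1, 0, False  # a single bit is enough to store zero

    if math.log2(abs(value)).is_integer():
        # a power of two needs one value bit; the integer width only places it
        # and can be negative for magnitudes below 1 (e.g. 0.03125)
        integer = int(math.log2(abs(value))) + 1
        return (2, integer + 1, True) if signed else (1, integer, False)

    # let fxpmath find an exact representation, e.g. 1.25 -> 3 bits unsigned
    fxp = Fxp(value, signed=signed)
    width = fxp.n_word
    integer = fxp.n_int + (1 if signed else 0)  # fxpmath's n_int excludes the sign bit
    if width - (1 if signed else 0) > max_width:
        # values such as 1.1 would need ~56 bits exactly; cap the width and keep
        # the integer part so the covered range stays the same
        width = max_width + (1 if signed else 0)
    return width, integer, signed

With these rules a slope of 1.25 comes out as a 3-bit unsigned type, while 1.1 is capped at 8 bits (9 with the sign bit for -1.1), matching the widths exercised by the unit test suggested further down.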

Contributor

Makes sense, thanks. I think I will then implement something similar to the non-power-of-2 handling here for brevitas.

Contributor

The optimizer can set negative integer bitwidths if it needs the precision for smaller values (for example, a constant like 0.03125 fits in a single bit with a negative integer width).

@JanFSchulte
Contributor

There were some weird pytest failures that I'm rerunning, but otherwise I think this can be merged now.

@nghielme
Contributor Author

Looks good to me. One small note: I think the test could be rewritten in a more pytest-idiomatic way, like this:

@pytest.mark.parametrize(
    "val, expected_width",
    [
        (0, 1),
        (-1024, 2),
        (1024, 1),
        (0.03125, 1),
        (-0.03125, 2),
        (1.25, 3),
        (-1.25, 4),
        (1.1, 8),
        (-1.1, 9),
    ]
)
def test_precision_from_constant_unit(val, expected_width):
    """Test determining precision needed for a constant."""
    max_width = 8
    fp = _get_precision_from_constant(val, max_width)

    assert fp.min <= val <= fp.max
    assert fp.width == expected_width
    assert fp.signed == (val < 0)

    quantum = 2.0 ** -fp.fractional
    if expected_width < max_width:
        assert val % quantum == 0

@JanFSchulte
Contributor

Tests have only the "expected" failures now, so I think this is ok. I agree with Nicolo's comment on the pytest though, so if you could integrate that before merging that would be great, Jovan.
