Improve precision for non-power-of-2 scales in brevitas -> QONNX workflow and add pytests #1208

Draft: wants to merge 3 commits into base: main
8 changes: 8 additions & 0 deletions hls4ml/model/optimizer/passes/quant_opt.py
@@ -343,8 +343,16 @@ def transform(self, model, node):

rescale = scale
rebias = -bias * scale

# The precision of the scale is important for overall model accuracy, so it is increased here.
# Doubling the fractional bits is a crude heuristic and needs a better solution.
frac_bits = node.get_attr('bitwidth') * 2
scale_precision, scale_quantizer = _calculate_precision_quantizer(frac_bits, 0, signed, narrow, rounding_mode)

attributes_rescale['scale_data'] = np.broadcast_to(rescale, inshape)
attributes_rescale['bias_data'] = np.broadcast_to(rebias, inshape)
attributes_rescale['scale_quantizer'] = scale_quantizer
attributes_rescale['scale_precision'] = scale_precision

rescale_node = model.make_node(
ApplyAlpha, rescale_name, attributes_rescale, [x for x in node.inputs], [x for x in node.outputs]
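The comment in the hunk above is easy to verify numerically. Below is a minimal sketch (not part of the PR; the scale value and the to_fixed helper are hypothetical) showing why a non-power-of-2 scale needs roughly twice the fractional bits of the data before the dequantization error becomes negligible:

import numpy as np

def to_fixed(x, frac_bits):
    # round x to the nearest multiple of 2**-frac_bits (a plain fixed-point grid)
    return np.round(x * 2.0**frac_bits) / 2.0**frac_bits

scale = 0.0123                # hypothetical non-power-of-2 quantization scale
bitwidth = 8
codes = np.arange(-128, 128)  # all int8 code points

for frac_bits in (bitwidth, 2 * bitwidth):
    err = np.max(np.abs(codes * scale - codes * to_fixed(scale, frac_bits)))
    print(f'frac_bits={frac_bits:2d}  max dequantization error={err:.1e}')

With 8 fractional bits the worst-case error is about 7e-2, on the order of the dequantized values themselves; with 16 fractional bits it drops to about 2e-4.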
78 changes: 78 additions & 0 deletions test/pytest/test_qonnx.py
@@ -2,16 +2,27 @@
import urllib
from pathlib import Path

# To test the workflow from brevitas
import brevitas.nn as qnn
import numpy as np
import pytest
import qonnx.core.onnx_exec as oxe
import qonnx.util.cleanup
import qonnx.util.to_channels_last
import torch
from brevitas.export import export_qonnx
from brevitas.quant import (
Int8ActPerTensorFixedPoint,
Int8ActPerTensorFloat,
Int8WeightPerTensorFixedPoint,
Int8WeightPerTensorFloat,
)

# To conveniently run QONNX inference
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
from qonnx.transformation.gemm_to_matmul import GemmToMatMul
from torch.nn import Module

import hls4ml

@@ -432,3 +443,70 @@ def test_simple_model(model_name, io_type, backend, request):
y_hls4ml = hls_model.predict(X)

np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1)


# Test brevitas -> QONNX -> hls4ml workflow
quants = {
'Int8WeightPerTensorFloat': Int8WeightPerTensorFloat,
'Int8WeightPerTensorFixedPoint': Int8WeightPerTensorFixedPoint,
'Int8ActPerTensorFloat': Int8ActPerTensorFloat,
'Int8ActPerTensorFixedPoint': Int8ActPerTensorFixedPoint,
}


class QuantModelLinear(Module):
def __init__(self, weight_quant, act_quant):
super().__init__()
self.lin1 = qnn.QuantLinear(4, 4, bias=True, weight_quant=quants[weight_quant], input_quant=quants[act_quant])
self.relu1 = qnn.QuantReLU(act_quant=quants[act_quant])

def forward(self, x):
out = self.relu1(self.lin1(x))
return out


# FixedPoint quantizers yield power-of-2 quantization scales; Float quantizers yield non-power-of-2 scales
@pytest.mark.parametrize('backend', ['Vitis'])
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
@pytest.mark.parametrize('quant_type', ['Float', 'FixedPoint'])
def test_brevitas_workflow(backend, io_type, quant_type):
weight_quant = f'Int8WeightPerTensor{quant_type}'
act_quant = f'Int8ActPerTensor{quant_type}'

model = QuantModelLinear(weight_quant, act_quant)

x = torch.rand(1, 4)

output_path = 'brevitas_onnx.onnx'
_ = export_qonnx(model, input_t=x, export_path=output_path)

modelQONNX = ModelWrapper(output_path)
modelQONNX = qonnx.util.cleanup.cleanup_model(modelQONNX)
modelQONNX = modelQONNX.transform(ConvertToChannelsLastAndClean())
modelQONNX = modelQONNX.transform(GemmToMatMul())
modelQONNX = qonnx.util.cleanup.cleanup_model(modelQONNX)

pytorch_prediction = model(x).detach().numpy()

configQONNX = hls4ml.utils.config.config_from_onnx_model(
modelQONNX, granularity='name', backend=backend, default_precision='fixed<16,6>'
)
# modify the config as desired
hls_modelQONNX = hls4ml.converters.convert_from_onnx_model(
modelQONNX,
output_dir=str(test_root_path / f'hls4mlprj_onnx_brevitas_{quant_type.lower()}_{io_type}_{backend}'),
io_type=io_type,
backend=backend,
hls_config=configQONNX,
)
hls_modelQONNX.compile()

hls_predictionQONNX = np.reshape(hls_modelQONNX.predict(x.detach().numpy()), pytorch_prediction.shape)

np.testing.assert_allclose(pytorch_prediction, hls_predictionQONNX, rtol=0.0, atol=0.05)
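
As a side note, the premise of the quant_type parametrization (FixedPoint scales are powers of 2, Float scales are not) can be checked directly on the exported model. A minimal sketch, assuming the brevitas_onnx.onnx file written by the test above, that inspects each Quant node's scale:

import numpy as np
from qonnx.core.modelwrapper import ModelWrapper

def is_power_of_two(scale):
    # a positive scale is a power of 2 exactly when log2(scale) is an integer
    return np.isclose(np.log2(scale), np.round(np.log2(scale)))

model = ModelWrapper('brevitas_onnx.onnx')
for node in model.graph.node:
    if node.op_type == 'Quant':
        scale = model.get_initializer(node.input[1])  # Quant inputs: (x, scale, zero point, bitwidth)
        if scale is not None:  # the scale may be a dynamic input rather than an initializer
            print(node.name, scale.ravel(), is_power_of_two(scale.ravel()))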