Draft
28 commits
d093913
Added support for BipolarQuant. It's converted to BinaryQuant in hls4ml.
Apr 28, 2025
4b66180
Added binarized qonnx model for testing the binary quant transformation
May 7, 2025
721d598
Pre-commit fixes
May 7, 2025
768c6a9
Merge branch 'main' into qonnx_binary_quant
jurevreca12 May 7, 2025
6a74bfb
Removed BipolarQuantConstantParameters, because such an optimization …
May 8, 2025
10e1af0
Merge branch 'qonnx_binary_quant' of https://github.com/jurevreca12/h…
May 8, 2025
2dfdb25
Limited FuseBipolarQuantWithConstant to only support scale factors of 1
May 8, 2025
89e2136
Removed bipolar_quant_constant_parameters from list of optimizations,…
May 8, 2025
76968b5
Modified the optimizations to only consider transform when scaling fa…
May 8, 2025
8a20361
Removed left-over docs from copying.
May 8, 2025
7bd4d94
Revert "Removed BipolarQuantConstantParameters"
May 9, 2025
144d427
Revert "Removed bipolar_quant_constant_parameters from list of optimi…
May 9, 2025
8d6aae2
Removed onnx model from repo. Using example-models for that instead.
May 9, 2025
18ce38e
Added test for non-unit (po2) scaling factors
May 13, 2025
08dafdf
Pre-commit fixes.
May 13, 2025
63feffc
Merge branch 'main' into qonnx_binary_quant
jmitrevs Jun 11, 2025
2649762
Merge branch 'main' into qonnx_binary_quant
jmitrevs Jul 28, 2025
4131325
add IntQuant parsing support
jmitrevs Jul 29, 2025
85edc62
don't merge BN to Dense, etc, if binary weights
jmitrevs Jul 31, 2025
9e6f800
change some things that assume name == output[0], and reorder so that …
jmitrevs Aug 2, 2025
cfd942d
snapshot of progress towards bipolar quant
jmitrevs Aug 2, 2025
b6b3751
Merge remote-tracking branch 'upstream/main' into qonnx_binary_quant_dev
jmitrevs Aug 4, 2025
fa3e407
fix pre-commit errors
jmitrevs Aug 4, 2025
c17a665
fix overwriting of precision attribute, fix cast for Vitis
jmitrevs Aug 5, 2025
6d2c18a
update bipolar quant for nonunitary scale
jmitrevs Aug 5, 2025
f73414c
make initial versions for Vitis streaming, oneAPI, and Catapult. (Vit…
jmitrevs Aug 6, 2025
966adff
change ap_int to ac_int for oneAPI and Catapult
jmitrevs Aug 6, 2025
ca6d319
unify the comparisons with 0; add more tests
jmitrevs Aug 6, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -14,3 +14,4 @@ docs/autodoc/*
hls4mlprj_*
*~
*.ipynb_checkpoints/
*.bak
2 changes: 1 addition & 1 deletion example-models
13 changes: 12 additions & 1 deletion hls4ml/converters/onnx/core.py
@@ -105,7 +105,7 @@ def parse_batchnorm_layer(node, input_names, input_shapes, graph):
return layer


@onnx_handler('Quant')
@onnx_handler('Quant', 'IntQuant')
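# 'IntQuant' is the newer QONNX name for the integer Quant operator; both names share this parser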
def parse_quant_layer(node, input_names, input_shapes, graph):
layer = {}

@@ -120,3 +120,14 @@ def parse_quant_layer(node, input_names, input_shapes, graph):
layer['signed'] = bool(get_onnx_attribute(node, 'signed'))

return layer


@onnx_handler('BipolarQuant')
def parse_bipolar_quant_layer(node, input_names, input_shapes, graph):
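# QONNX BipolarQuant carries only two inputs (the tensor and its scale) and, unlike Quant/IntQuant, no bitwidth or signedness attributes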
layer = {}

layer['class_name'] = 'BipolarQuant'
layer['name'] = node.name
layer['inputs'] = input_names
layer['outputs'] = list(node.output)
return layer
18 changes: 10 additions & 8 deletions hls4ml/model/graph.py
@@ -693,6 +693,11 @@ def replace_node(self, old_node, new_node):
repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node.outputs)}
repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node.inputs)})

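# if any of the old node's output tensors are also model outputs, retarget them to the corresponding new output tensors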
for old_output in old_node.outputs:
if old_output in self.outputs:
new_output = repl[old_output]
self.outputs = [new_output if name == old_output else name for name in self.outputs]

for node in self.graph.values():
for i, n in enumerate(node.inputs):
if n in repl:
@@ -703,11 +708,6 @@ def replace_node(self, old_node, new_node):

self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items())

old_name = old_node.name
if old_name in self.outputs:
new_name = new_node.name
self.outputs = [new_name if name == old_name else name for name in self.outputs]

def split_node(self, old_node, new_node1, new_node2):
"""Replace an existing node in the graph with two nodes in sequence.

@@ -728,6 +728,11 @@ def split_node(self, old_node, new_node1, new_node2):
repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node2.outputs)}
repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node1.inputs)})

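# model outputs that referred to the old node's output tensors now point at new_node2's outputs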
for old_output in old_node.outputs:
if old_output in self.outputs:
new_output = repl[old_output]
self.outputs = [new_output if name == old_output else name for name in self.outputs]

for node in self.graph.values():
for i, n in enumerate(node.inputs):
if n in repl:
@@ -745,9 +750,6 @@ def split_node(self, old_node, new_node1, new_node2):
new_graph[key] = value
self.graph = new_graph

if old_node.name in self.outputs:
self.outputs = [new_node2.name if name == old_node.name else name for name in self.outputs]

def next_layer(self):
self.index += 1
return self.index
19 changes: 18 additions & 1 deletion hls4ml/model/layers.py
@@ -430,6 +430,21 @@ def initialize(self):
self.add_output_variable(shape, dims)


class BipolarQuant(Layer): # The QONNX quantization layer
"""
This is a QONNX quantization layer. Optimizations should convert it
before HLS is produced.
"""

_expected_attributes = []
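# no attributes are required up front: the scale arrives either as a second input or, once folded, as a 'scale' attribute set by the optimizer passes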

def initialize(self):
inp = self.get_input_variable(self.inputs[0])
shape = inp.shape
dims = inp.dim_names
self.add_output_variable(shape, dims)


class Reshape(Layer):
_expected_attributes = [
Attribute('target_shape', value_type=typing.Sequence),
@@ -977,7 +992,7 @@ def initialize(self):
inp = self.get_input_variable()
shape = inp.shape
dims = inp.dim_names
self.add_output_variable(shape, dims)
self.add_output_variable(shape, dims, precision=self.get_attr('quantizer_precision')) # for xor precision
if 'n_in' not in self.attributes:
self.set_attr('n_in', self.get_input_variable().size())

@@ -1898,6 +1913,8 @@ def initialize(self):
'GarNet': GarNet,
'GarNetStack': GarNetStack,
'Quant': Quant,
'IntQuant': Quant,
'BipolarQuant': BipolarQuant,
'ApplyAlpha': ApplyAlpha,
'BatchNormOnnx': BatchNormOnnx,
'LayerGroup': LayerGroup,
7 changes: 6 additions & 1 deletion hls4ml/model/optimizer/__init__.py
@@ -36,10 +36,15 @@
'reshape_constant',
'resize_remove_constants',
'quant_constant_parameters',
'quant_to_activation',
'bipolar_quant_constant_parameters',
'fuse_quant_with_constant',
'fuse_bipolar_quant_with_constant',
'quant_to_activation',
'bipolar_quant_to_activation',
'const_quant_to_const_alpha',
'const_bipolar_quant_to_const_alpha',
'quant_to_alpha_activation_alpha',
'bipolar_quant_to_alpha_activation_alpha',
'batch_norm_onnx_constant_parameters',
'constant_batch_norm_fusion',
'merge_two_constants',
263 changes: 263 additions & 0 deletions hls4ml/model/optimizer/passes/bipolar_quant_opt.py
@@ -0,0 +1,263 @@
"""
This file includes optimizations related to BipolarQuant nodes.

"""

import copy

import numpy as np

from hls4ml.model.layers import Activation, ApplyAlpha, BipolarQuant, Constant
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.quantizers import BinaryQuantizer
from hls4ml.model.types import XnorPrecisionType


class BipolarQuantConstantParameters(OptimizerPass):
"""Remove Constant from the BipolarQaunt node parameters (but not input[0])"""

def match(self, node):
is_match = (
isinstance(node, BipolarQuant)
and len(node.inputs) == 2
and (node.get_input_node(node.inputs[1]) and isinstance(node.get_input_node(node.inputs[1]), Constant))
)

return is_match

def transform(self, model, node):
"""
Remove Constant from the BipolarQuant node parameters (but not input[0])
"""
if node.get_input_node(node.inputs[1]):
scale_node = node.get_input_node(node.inputs[1])
if isinstance(scale_node, Constant):
node.set_attr('scale', scale_node.get_attr('value'))
node.inputs[1] = ''
model.remove_node(scale_node)

node.inputs = [inp for inp in node.inputs if inp]
if len(node.inputs) != 1:
raise RuntimeError("hls4ml only supports constant scale")

return True


class BipolarQuantToActivation(OptimizerPass):
"""
This is for the case when scale is 1. It is a 1:1 transformation of a BipolarQuant to an Activation.
This is not called when the input is constant.
"""

def match(self, node):
# only matches after the other inputs are already folded
is_match = (
isinstance(node, BipolarQuant)
and len(node.inputs) == 1
and not isinstance(node.get_input_node(node.inputs[0]), Constant)
)

# Only match if the scale is 1
if is_match: # to make sure this is a quant node with inputs
scale = node.get_attr('scale')
is_match = (scale == 1.0).all()

return is_match

def transform(self, model, node):
"""
Change BipolarQuant node to Activation
"""
precision = XnorPrecisionType()
quantizer = BinaryQuantizer(bits=1)

attributes = {'activation': 'binary_tanh', 'quantizer': quantizer, 'quantizer_precision': precision}

# update the configuration (not setting the precision since can't specify xnor type)
config = model.config.get_layer_config(node)
new_name = f'{node.name}_act'
model.config.set_name_config(new_name, config)
model.config.parse_name_config(new_name, config)

new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], list(node.outputs))
model.replace_node(node, new_node)
return True


class FuseBipolarQuantWithConstant(OptimizerPass):
"""
This is for the case when scale is 1 and the input is a constant
"""

def match(self, node):

# only matches after the other inputs are already folded
# and scale is unit
is_match = (
isinstance(node, BipolarQuant)
and len(node.inputs) == 1
and isinstance(node.get_input_node(node.inputs[0]), Constant)
)

# Only match if the scale is 1
if is_match: # to make sure this is a quant node with inputs
scale = node.get_attr('scale')
is_match = (scale == 1.0).all()

return is_match

def transform(self, model, node):
"""
Fuse BipolarQuant with Constant.
"""
precision = XnorPrecisionType()
quantizer = BinaryQuantizer(bits=1)

const_node = node.get_input_node(node.inputs[0])
const_node.set_attr('quantizer', quantizer)
const_node.get_output_variable().type.precision = precision

# remove the Quant node
model.remove_node(node)
return True


class BipolarQuantToAlphaActivationAlpha(OptimizerPass):
"""
This is for the case when scale is not 1. It is a 1:3 transformation of
a BipolarQuant to an ApplyAlpha (to scale), Activation, ApplyAlpha (to rescale).

NOTE: It needs to be scheduled after BipolarQuantToActivation (or we need to make the match criteria stricter)
"""

def match(self, node):
# only matches after the other inputs are already folded
is_match = (
isinstance(node, BipolarQuant)
and len(node.inputs) == 1
and not isinstance(node.get_input_node(node.inputs[0]), Constant)
)
return is_match

def transform(self, model, node):
"""
Change quant node to ApplyAlpha, Activation, ApplyAlpha
"""

# Do the Activation as in the simple case

precision = XnorPrecisionType()
quantizer = BinaryQuantizer(bits=1)

activation_attributes = {'activation': 'binary_tanh', 'quantizer': quantizer, 'quantizer_precision': precision}

# update the configuration (not setting the precision since can't specify xnor type)
config = model.config.get_layer_config(node)
act_config = copy.deepcopy(config)
act_name = f'{node.name}_act'
model.config.set_name_config(act_name, act_config)
model.config.parse_name_config(act_name, act_config)

new_node = model.make_node(Activation, act_name, activation_attributes, [node.inputs[0]], [x for x in node.outputs])
model.replace_node(node, new_node)

# but now add the ApplyAlphas before and after

inshape = node.get_input_variable().shape

scale = node.get_attr('scale')
bias = np.array(0)

attributes_scale = {'n_filt': -1}
attributes_rescale = {'n_filt': -1}

scale_config = copy.deepcopy(config)
scale_name = f'{node.name}_scale'
model.config.set_name_config(scale_name, scale_config)
model.config.parse_name_config(scale_name, scale_config)

rescale_config = config # no need to deep copy the last
rescale_name = f'{node.name}_rescale'
model.config.set_name_config(rescale_name, rescale_config)
model.config.parse_name_config(rescale_name, rescale_config)

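# a scaled BipolarQuant behaves like scale * sign(x / scale): divide out the scale, apply the binary activation, then multiply the scale back in; the bias terms stay zero since BipolarQuant has no zero point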
firstscale = 1 / scale
firstbias = bias
attributes_scale['scale_data'] = np.broadcast_to(firstscale, inshape)
attributes_scale['bias_data'] = np.broadcast_to(firstbias, inshape)

scale_node = model.make_node(ApplyAlpha, scale_name, attributes_scale, [node.inputs[0]])
model.insert_node(scale_node)

rescale = scale
rebias = -bias * scale
attributes_rescale['scale_data'] = np.broadcast_to(rescale, inshape)
attributes_rescale['bias_data'] = np.broadcast_to(rebias, inshape)

rescale_node = model.make_node(ApplyAlpha, rescale_name, attributes_rescale, [new_node.outputs[0]])
model.insert_node(rescale_node)

return True


class ConstBipolarQuantToConstAlpha(OptimizerPass):
"""
This is for the case when scale is not 1. It is a 1:3 transformation of
a BipolarQuant to an ApplyAlpha (to scale), Activation, ApplyAlpha (to unscale), but a constant
input allows for optimization, so the ApplyAlpha (to scale) and Activation are
optimized away right away.
"""

def match(self, node):
# only matches after the other inputs are already folded
is_match = (
isinstance(node, BipolarQuant)
and len(node.inputs) == 1
and isinstance(node.get_input_node(node.inputs[0]), Constant)
)

if is_match: # to make sure this is a quant node with inputs
scale = node.get_attr('scale')
is_match = is_match and ((scale != np.ones_like(scale)).any())
return is_match

def transform(self, model, node):
"""
Change Constant + Quant node to Constant, ApplyAlpha
"""

precision = XnorPrecisionType()
quantizer = BinaryQuantizer(bits=1)

const_node = node.get_input_node(node.inputs[0])

scale = node.get_attr('scale')
bias = np.array(0) # zeropt not defined for bipolar quants

# calculate the new value
new_val = const_node.get_attr('value') / scale + bias
const_node.set_attr('value', new_val)
const_node.set_attr('quantizer', quantizer)

const_node.get_output_variable().type.precision = precision

inshape = node.get_input_variable().shape

attributes_rescale = {'n_filt': -1}

rescale_config = copy.deepcopy(model.config.get_layer_config(node))
rescale_name = f'{node.name}_rescale'
model.config.set_name_config(rescale_name, rescale_config)
model.config.parse_name_config(rescale_name, rescale_config)

rescale = scale
rebias = -bias * scale
attributes_rescale['scale_data'] = np.broadcast_to(rescale, inshape)
attributes_rescale['bias_data'] = np.broadcast_to(rebias, inshape)

rescale_node = model.make_node(
ApplyAlpha, rescale_name, attributes_rescale, [x for x in node.inputs], [x for x in node.outputs]
)
model.replace_node(node, rescale_node)

return True