
PQuant🌶️ integration #1362


Draft
wants to merge 12 commits into base: main
5 changes: 3 additions & 2 deletions hls4ml/backends/vivado/passes/core_templates.py
@@ -7,6 +7,7 @@
BatchNormalization,
Dense,
HardActivation,
LayerNormalization,
MultiplierReLU,
ParametrizedActivation,
PReLU,
@@ -268,7 +269,7 @@ def format(self, node):

class ParamActivationConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__((ParametrizedActivation, PReLU))
super().__init__((ParametrizedActivation, PReLU, MultiplierReLU))
self.template = param_activ_config_template

def format(self, node):
@@ -381,7 +382,7 @@ def format(self, node):

class PReLUFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(PReLU, include_header=activ_include_list)
super().__init__((PReLU, MultiplierReLU), include_header=activ_include_list)
self.template = param_activ_function_template

def format(self, node):
1 change: 1 addition & 0 deletions hls4ml/converters/keras_v3/__init__.py
@@ -4,6 +4,7 @@
from . import hgq2 # noqa: F401
from . import merge # noqa: F401
from . import pooling # noqa: F401
from . import pquant # noqa: F401
from . import recurrent # noqa: F401
from ._base import registry as layer_handlers

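The new pquant module only needs to be imported here: its handlers register themselves via the @register decorator, keyed by the dotted class paths listed in handles, so the converter picks them up automatically. A minimal, hypothetical handler following the same pattern (the layer path and config values below are illustrative and not part of this PR):

# Hypothetical example of the self-registration pattern used by the pquant handlers.
# Importing the module is enough; @register adds the class to the shared registry
# that keras_v3/__init__.py re-exports as layer_handlers.
from hls4ml.converters.keras_v3._base import KerasV3LayerHandler, register


@register
class MyCustomLayerHandler(KerasV3LayerHandler):
    handles = ('my_package.MyCustomLayer',)  # hypothetical dotted path of the Keras layer class

    def handle(self, layer, in_tensors, out_tensors):
        config = {}
        config.update(self.default_config)
        config['class_name'] = 'Activation'
        config['activation'] = 'linear'
        return (config,)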
3 changes: 3 additions & 0 deletions hls4ml/converters/keras_v3/conv.py
@@ -141,4 +141,7 @@ def handle(
elif isinstance(layer, BaseConv):
config['weight_data'] = kernel

if hasattr(layer, 'quantization_parameters'):
config['quantization_parameters'] = layer.quantization_parameters

return config
4 changes: 4 additions & 0 deletions hls4ml/converters/keras_v3/core.py
@@ -33,6 +33,10 @@ def handle(
'n_out': n_out,
'n_in': n_in,
}

if hasattr(layer, 'quantization_parameters'):
config['quantization_parameters'] = layer.quantization_parameters

return config


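All of the converter changes in this PR forward the layer's quantization_parameters dictionary into the hls4ml layer config unchanged. As a purely hypothetical illustration of its contents (only the keys that the handlers below actually branch on are listed; the real PQuant dictionary is not shown in this diff and likely carries more entries):

# Hypothetical sketch of a PQuant quantization_parameters dict; key names are taken
# from the checks in this PR, values are illustrative only.
quantization_parameters = {
    'use_high_granularity_quantization': False,
    'use_relu_multiplier': True,  # QuantizedReLU -> MultiplierReLU when True (and HGQ is off)
    'use_real_tanh': False,       # QuantizedTanh -> HardActivation (hard_tanh) when False
}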
127 changes: 127 additions & 0 deletions hls4ml/converters/keras_v3/pquant.py
@@ -0,0 +1,127 @@
import typing
from collections.abc import Sequence

import numpy as np

from hls4ml.model.types import FixedPrecisionType

from ._base import KerasV3LayerHandler, register
from .conv import gen_conv_config

if typing.TYPE_CHECKING:
import pquant
from keras import KerasTensor


@register
class PQuantReLUHandler(KerasV3LayerHandler):
handles = ('pquant.core.activations_quantizer.QuantizedReLU',)

def handle(
self,
layer: 'pquant.core.activations_quantizer.QuantizedReLU',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
config = {}
config.update(self.default_config)
config['quantization_parameters'] = layer.quantization_parameters

if (
not config['quantization_parameters']['use_high_granularity_quantization']
and layer.config['quantization_parameters']['use_relu_multiplier']
):
config['class_name'] = 'MultiplierReLU'
config['param_data'] = np.array(layer.multiplier)
config['activation'] = 'multiplier_relu'

else:
config['class_name'] = 'QActivation'
config['activation'] = 'relu'

return (config,)


@register
class PQuantTanhHandler(KerasV3LayerHandler):
handles = ('pquant.core.activations_quantizer.QuantizedTanh',)

def handle(
self,
layer: 'pquant.core.activations_quantizer.QuantizedTanh',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
config = {}
config.update(self.default_config)
config['quantization_parameters'] = layer.quantization_parameters

if not layer.config['quantization_parameters']['use_real_tanh']:
config['class_name'] = 'HardActivation'
config['slope'] = 0.5 # the default values in QKeras
config['shift'] = 0.5
# Quartus seems to have trouble if the width is 1.
config['slope_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
config['shift_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
config['activation'] = 'hard_tanh'

else:
config['class_name'] = 'QActivation'
config['activation'] = 'tanh'

return (config,)


@register
class PQuantPoolingHandler(KerasV3LayerHandler):
handles = ('pquant.core.tf_impl.compressed_layers_tf.QuantizedPooling',)

def handle(
self,
layer: 'pquant.core.tf_impl.compressed_layers_tf.QuantizedPooling',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
assert len(in_tensors) == 1, f'Layer {layer.name} has more than one input'
assert len(out_tensors) == 1, f'Layer {layer.name} has more than one output'

in_shape: tuple[int, ...] = in_tensors[0].shape[1:] # type: ignore
out_shape: tuple[int, ...] = out_tensors[0].shape[1:] # type: ignore
assert all(isinstance(x, int) for x in in_shape), f'Layer {layer.name} has non-fixed size input: {in_shape}'
assert all(isinstance(x, int) for x in out_shape), f'Layer {layer.name} has non-fixed size output: {out_shape}'

data_format = layer.data_format

if data_format == 'channels_last':
*px_in_shape, _ = in_shape
else:
_, *px_in_shape = in_shape

pool_size: tuple[int, ...] = layer.pool_size

strides = layer.strides
padding = layer.padding
pooling_config = gen_conv_config(
in_shape=in_shape,
out_shape=out_shape,
ker_px_shape=pool_size,
strides=strides,
data_format=data_format,
padding=padding,
name=layer.name,
)

pooling_config['pool_width'] = pooling_config.pop('filt_width')
if 'filt_height' in pooling_config:
pooling_config['pool_height'] = pooling_config.pop('filt_height')
if len(px_in_shape) == 1:
# inconsistent pooling1d config key name...
pooling_config['n_in'] = pooling_config['in_width']
pooling_config['n_out'] = pooling_config['out_width']

config = {}
config.update(self.default_config)
config.update(pooling_config)
config['class_name'] = f'AveragePooling{layer.dimensions}D'
config['quantization_parameters'] = layer.quantization_parameters
return (config,)
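For orientation, a minimal end-to-end sketch of how these handlers would be exercised. Only config_from_keras_model and convert_from_keras_model are existing hls4ml entry points; the PQuant model itself is assumed to exist and its construction is out of scope here:

# Sketch: converting a Keras (v3) model that contains PQuant layers such as
# QuantizedReLU / QuantizedTanh / QuantizedPooling. `model` is assumed to be
# such a model; building it with PQuant is not shown.
import hls4ml

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    output_dir='pquant_test_prj',
    backend='Vivado',
)
hls_model.compile()  # builds the C simulation library for numerical comparison against Keras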
8 changes: 8 additions & 0 deletions hls4ml/converters/pytorch/convolution.py
@@ -44,6 +44,10 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c

output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_width']] # Channel first as default

# Quantization parameters for PQuant integration
if hasattr(class_object, "quantization_parameters"):
layer['quantization_parameters'] = class_object.quantization_parameters

return layer, output_shape


@@ -94,4 +98,8 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c

output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_height'], layer['out_width']]

# Quantization parameters for PQuant integration
if hasattr(class_object, "quantization_parameters"):
layer['quantization_parameters'] = class_object.quantization_parameters

return layer, output_shape
4 changes: 4 additions & 0 deletions hls4ml/converters/pytorch/core.py
@@ -54,6 +54,10 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c
output_shape = input_shapes[0][:]
output_shape[-1] = layer['n_out']

# Quantization parameters for PQuant integration
if hasattr(class_object, "quantization_parameters"):
layer['quantization_parameters'] = class_object.quantization_parameters

return layer, output_shape


69 changes: 69 additions & 0 deletions hls4ml/converters/pytorch/pquant.py
@@ -0,0 +1,69 @@
from hls4ml.converters.pytorch.core import parse_activation_layer
from hls4ml.converters.pytorch.pooling import parse_pooling_layer
from hls4ml.converters.pytorch_to_hls import pytorch_handler
from hls4ml.model.types import FixedPrecisionType


@pytorch_handler('QuantizedActivationTorchWrapper')
def parse_pquant_activation_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):

layer, output_shape = parse_activation_layer(
class_object.activation.__class__.__name__,
layer_name,
input_names,
input_shapes,
node,
class_object.activation,
data_reader,
config,
)
layer['quantization_parameters'] = class_object.activation.quantization_parameters

if (
layer['activation'] == 'quantizedtanh'
and not class_object.activation.config['quantization_parameters']['use_real_tanh']
):
layer['class_name'] = 'HardActivation'
layer['slope'] = 0.5 # the default values in QKeras
layer['shift'] = 0.5
# Quartus seems to have trouble if the width is 1.
layer['slope_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
layer['shift_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
layer['activation'] = 'hard_tanh'

elif (
layer['activation'] == 'quantizedrelu'
and not layer['quantization_parameters']["use_high_granularity_quantization"]
and class_object.activation.config['quantization_parameters']['use_relu_multiplier']
):
layer['class_name'] = 'MultiplierReLU'
layer['param_data'] = class_object.activation.multiplier.numpy()
layer['activation'] = 'multiplier_relu'

else:
layer['class_name'] = 'QActivation'
activation_map = {
'quantizedrelu': 'relu',
'quantizedtanh': 'tanh',
}
layer['activation'] = activation_map.get(layer['activation'], layer['activation'])

return layer, output_shape


@pytorch_handler('QuantizedPooling')
def parse_pquant_pooling_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):

layer, output_shape = parse_pooling_layer(
class_object.pooling.__class__.__name__,
layer_name,
input_names,
input_shapes,
node,
class_object.pooling,
data_reader,
config,
)
layer['quantization_parameters'] = class_object.quantization_parameters

return layer, output_shape
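Both the Keras and the PyTorch handler lower QuantizedTanh to HardActivation with slope = shift = 0.5 (the QKeras defaults referenced in the comments). Assuming hls4ml's hard_tanh follows the QKeras convention hard_tanh(x) = 2 * clip(slope * x + shift, 0, 1) - 1, these defaults reduce to a clamp to [-1, 1]; a NumPy reference of that assumption (editor sketch, not part of the PR):

# NumPy reference for the hard_tanh approximation that QuantizedTanh is mapped to,
# assuming the QKeras convention hard_tanh(x) = 2 * clip(slope * x + shift, 0, 1) - 1.
# With slope = shift = 0.5 this is simply a clamp to [-1, 1].
import numpy as np


def hard_tanh_reference(x, slope=0.5, shift=0.5):
    return 2.0 * np.clip(slope * x + shift, 0.0, 1.0) - 1.0


print(hard_tanh_reference(np.array([-3.0, -1.0, 0.0, 0.5, 3.0])))  # -> [-1., -1., 0., 0.5, 1.]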
5 changes: 4 additions & 1 deletion hls4ml/converters/pytorch_to_hls.py
@@ -352,7 +352,10 @@ def parse_pytorch_model(config, verbose=True):
if '.' not in node.target:
obj = getattr(model, node.name)
else:
obj = getattr(children[node.target.split('.')[0], node.name])
if '_' not in node.name:
obj = getattr(children[node.target.split('.')[0]], node.name)
else:
obj = getattr(children[node.target.split('.')[0]], node.name.split('_')[1])

input_layer = {}
input_layer['name'] = node.name
15 changes: 15 additions & 0 deletions hls4ml/model/layers.py
@@ -1011,6 +1011,20 @@ def initialize(self):
self.add_weights_variable(name='param', var_name='a{index}')


class MultiplierReLU(Activation):
_expected_attributes = [
Attribute('n_in'),
WeightAttribute('param'),
TypeAttribute('param'),
]

def initialize(self):
super().initialize()
self.add_weights_variable(
name='param', var_name='m{index}', precision=FixedPrecisionType(width=4, integer=4, signed=True)
)


class Softmax(Activation):
def initialize(self):
super().initialize()
@@ -1770,6 +1784,7 @@ def initialize(self):
'ThresholdedReLU': ParametrizedActivation,
'ELU': ParametrizedActivation,
'PReLU': PReLU,
'MultiplierReLU': MultiplierReLU,
'Softmax': Softmax,
'TernaryTanh': TernaryTanh,
'HardActivation': HardActivation,
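The default precision chosen for the param weight, FixedPrecisionType(width=4, integer=4, signed=True), is a signed 4-bit integer type, i.e. multiplier exponents in [-8, 7]. A quick sanity check of that range (editor sketch, not part of the PR):

# Range of a signed, purely integer fixed-point type of the given width
# (width=4, integer=4, signed=True -> exponents from -8 to 7).
width = 4
lo, hi = -(2 ** (width - 1)), 2 ** (width - 1) - 1
print(lo, hi)  # -8 7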
21 changes: 21 additions & 0 deletions hls4ml/templates/vivado/nnet_utils/nnet_activation.h
@@ -785,6 +785,27 @@ void prelu(data_T data[CONFIG_T::n_in], param_T alpha[CONFIG_T::n_in], res_T res
}
}

// *************************************************
// MultiplierReLU Activation
// *************************************************
template <class data_T, class multiplier_T, class res_T, typename CONFIG_T>
void multiplier_relu(data_T data[CONFIG_T::n_in], multiplier_T mul[1], res_T res[CONFIG_T::n_in]) {
#pragma HLS PIPELINE

data_T datareg;
for (int ii = 0; ii < CONFIG_T::n_in; ii++) {
datareg = data[ii];
if (datareg > 0) {

if (mul[0] >= 0)
res[ii] = datareg << mul[0];
else
res[ii] = datareg >> (-mul[0]);
} else
res[ii] = 0;
}
}

// *************************************************
// Binary TanH Activation
// *************************************************
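A NumPy reference of what the multiplier_relu kernel added above computes: the single multiplier weight is treated as a signed power-of-two exponent applied by shifting, so (ignoring fixed-point width and saturation effects) the layer is y = x * 2**m for x > 0 and 0 otherwise. Editor sketch, not part of the PR:

# NumPy reference for nnet::multiplier_relu: `m` is a signed shift exponent,
# so positive inputs are scaled by 2**m and non-positive inputs map to 0.
import numpy as np


def multiplier_relu_reference(x, m):
    return np.where(x > 0, x * 2.0 ** m, 0.0)


x = np.array([-1.5, 0.0, 0.5, 2.0])
print(multiplier_relu_reference(x, 2))   # [0., 0., 2., 8.]
print(multiplier_relu_reference(x, -1))  # [0., 0., 0.25, 1.]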