
Distributed Arithmetic strategy for Dense, Conv1/2D, and EinsumDense #1191


Open: wants to merge 65 commits into base: main
Commits (65)
c640c46
skip oneapi test if icpx doesn't exist
calad0i May 28, 2025
e95aabe
bit-exact-possible multidim softmax
calad0i Mar 8, 2025
0393f54
softmax fix
calad0i Mar 8, 2025
249206c
Merge branch 'skip_oneapi_test_if_icpx_doesnt_exist' into vivado-bit-…
calad0i Jun 4, 2025
a950636
move softmax attr to fpga backend, post rebase fix
calad0i Jun 4, 2025
904433d
add keras v3 object parser
calad0i Mar 7, 2025
a14e02e
add keras v3 layer handlers
calad0i Mar 7, 2025
a1bcb66
einsumdense and einsum
calad0i Mar 7, 2025
95b3e92
add einsum templates
calad0i Mar 8, 2025
3b7f1d6
bit-exact-possible multidim softmax
calad0i Mar 8, 2025
7037ea4
symbolic bitwidth infer util
calad0i Mar 8, 2025
a21e428
add qinterval test
calad0i Mar 8, 2025
f8a07ae
keras v2-v3 reshape fn compatibility patch
calad0i Mar 8, 2025
3ecc9ab
hgq2 layer handlers
calad0i Mar 8, 2025
d39cefa
add bit-exact enforcement pass
calad0i Mar 8, 2025
daede09
fix softmax accum fractional bits derivation
calad0i Mar 9, 2025
c316052
add qeinsum test
calad0i Mar 8, 2025
2195809
env update
calad0i Mar 8, 2025
d4a38d6
remove hgq v1 rewire behavior (superseded)
calad0i Mar 8, 2025
1896720
fix del method in config class
calad0i Mar 8, 2025
bfd8638
distributed arithmetic impl
calad0i Mar 8, 2025
27b1a03
distributed arithmetic impl w/ conv
calad0i Mar 8, 2025
c3f7085
distributed arithmetic templates
calad0i Mar 8, 2025
1c7f018
prevent pointwise override for DA strategy
calad0i Mar 8, 2025
03ad025
add test for DA
calad0i Mar 8, 2025
b454e55
update proj conf
calad0i Mar 8, 2025
bc8fc13
disable delay_constraint in da4ml
calad0i Mar 8, 2025
a0735ee
add hgq2 mha test
calad0i Mar 8, 2025
9a1a649
update ci template
calad0i Mar 8, 2025
558590f
require da4ml version
calad0i Mar 9, 2025
8554680
pre-commit fix
calad0i Mar 9, 2025
42d11c3
proper skip qeinsum test when condition not met
calad0i Mar 9, 2025
e2a6c1f
softmax and activation fix
calad0i Mar 15, 2025
218c607
hgq2 api change, prevent zero bw activation crash syn
calad0i Mar 15, 2025
f352024
qinterval and bn bit-exactness fix
calad0i Mar 18, 2025
854a886
fix einsum ax expansion and 0d output handling
calad0i Mar 18, 2025
fe3af42
fix merge templates
calad0i Mar 19, 2025
bc3377b
converter and bit-exact pass for ops layers
calad0i Mar 19, 2025
52bf0c1
use pointwise 2d for conv2d due to unknown flow changing
calad0i Mar 19, 2025
38c5eaf
fix einsum dense da impl corner case
calad0i Mar 20, 2025
c12aff5
qinterval type fix
calad0i Mar 20, 2025
60f151a
fix corner case in qkeras converted proxy
calad0i Apr 9, 2025
d68c452
support mha def in (q,v) format
calad0i Apr 22, 2025
dcc128c
update da4ml binding syntax
calad0i Apr 23, 2025
1ba6413
update da4ml binding syntax x2
calad0i Apr 24, 2025
38c6565
use fixedvararr obj for da codegen
calad0i May 5, 2025
bf98915
more general build_lib script
calad0i May 5, 2025
d36dff9
bring back hgq proxy embedded properties excl. precision
calad0i May 6, 2025
ed00630
fix streaming conv1/2d da regression
calad0i May 10, 2025
2e337e1
streaming template support for DA fix
calad0i May 11, 2025
3c6271f
allow non-po-2 avg pooling
calad0i May 17, 2025
f680611
ignore batch dim in parse_data_format
calad0i May 17, 2025
abbf859
keras v3 native pooling layer parser
calad0i May 17, 2025
8380dcf
globalpooling handler fix
calad0i May 20, 2025
f0de421
unary lut bw derivation update
calad0i May 20, 2025
4579b1a
keras 3.10 api change
calad0i May 23, 2025
d5eaf8d
namespace fix for pointwise conv
calad0i May 23, 2025
b92871e
use constexpr for dim def
calad0i May 23, 2025
8e57245
conv pf handling
calad0i May 26, 2025
3b72c6a
keras v3 api change
calad0i May 27, 2025
6d22c62
quality-of-life changes
calad0i May 30, 2025
c8990ed
kv3 parser update
calad0i May 30, 2025
05d71b5
shut up!
calad0i May 30, 2025
0d450ae
post-rebase import conflicts
calad0i Jun 4, 2025
3728c0d
remaining post-rebase fix
calad0i Jun 4, 2025
5 changes: 4 additions & 1 deletion .pre-commit-config.yaml
@@ -47,7 +47,10 @@ repos:
exclude: docs/conf.py
additional_dependencies: [flake8-bugbear, flake8-print]
args: ['--max-line-length=125', # github viewer width
'--extend-ignore=E203,T201'] # E203 is not PEP8 compliant
'--extend-ignore=E203,T201', # E203 is not PEP8 compliant
'--per-file-ignores=hls4ml/model/optimizer/passes/bit_exact.py:E741',
# i for #int w/o sign, I for #int w/ sign when massively processing bw conversions
]

- repo: https://github.com/mgedmin/check-manifest
rev: "0.50"
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -16,7 +16,7 @@ pipeline {
sh '''#!/bin/bash --login
conda activate hls4ml-py310
conda install -y jupyterhub pydot graphviz pytest pytest-cov
pip install pytest-randomly jupyter onnx>=1.4.0 matplotlib pandas seaborn pydigitalwavetools==1.1 pyyaml tensorflow==2.14 qonnx torch git+https://github.com/jmitrevs/qkeras.git@qrecurrent_unstack pyparsing
pip install pytest-randomly jupyter onnx>=1.4.0 matplotlib pandas seaborn pydigitalwavetools==1.1 pyyaml tensorflow==2.14 qonnx torch git+https://github.com/jmitrevs/qkeras.git@qrecurrent_unstack pyparsing quantizers da4ml
pip install -U ../ --user
./convert-keras-models.sh -x -f keras-models.txt
pip uninstall hls4ml -y'''
33 changes: 19 additions & 14 deletions hls4ml/backends/fpga/fpga_backend.py
@@ -7,7 +7,7 @@
import numpy as np

from hls4ml.backends.backend import Backend
from hls4ml.model.attributes import ChoiceAttribute, ConfigurableAttribute, TypeAttribute
from hls4ml.model.attributes import Attribute, ChoiceAttribute, ConfigurableAttribute, TypeAttribute
from hls4ml.model.layers import (
GRU,
LSTM,
@@ -109,32 +109,37 @@ def __init__(self, name):
act_attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8), description=descriptions.table_type))
self.attribute_map[Activation] = act_attrs

softmax_attrs = self.attribute_map.get(Softmax, [])
softmax_attrs.append(
softmax_attrs = [
Attribute('n_in'),
Attribute('activation', value_type=str),
Attribute('n_outer', value_type=int, default=1),
Attribute('n_inner', value_type=int, default=1),
ChoiceAttribute(
'implementation',
['latency', 'stable', 'argmax', 'legacy'],
default='stable',
description=descriptions.softmax_implementation,
)
)
softmax_attrs.append(
ConfigurableAttribute('skip', value_type=bool, default=False, description=descriptions.softmax_skip)
)
softmax_attrs.append(
),
ConfigurableAttribute('skip', value_type=bool, default=False, description=descriptions.softmax_skip),
TypeAttribute(
'exp_table',
default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
description=descriptions.table_type,
)
)
softmax_attrs.append(
),
TypeAttribute(
'inv_table',
default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
description=descriptions.table_type,
)
)
),
TypeAttribute(
'inv_inp',
default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
),
TypeAttribute(
'accum',
default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
),
]
self.attribute_map[Softmax] = softmax_attrs

def create_layer_class(self, layer_class):
6 changes: 5 additions & 1 deletion hls4ml/backends/fpga/passes/fix_softmax_table_size.py
@@ -6,7 +6,11 @@

class FixSoftmaxTableSize(OptimizerPass):
def match(self, node):
return isinstance(node, Softmax)
if not isinstance(node, Softmax):
return False
if 'inv_table_size' in node.attributes:
return False # handler generating inv_table_size sets it properly
return True

def transform(self, model, node: Layer):
inp_layer = node.get_input_node() # type: ignore
5 changes: 0 additions & 5 deletions hls4ml/backends/fpga/passes/hgq_proxy_model.py
@@ -52,10 +52,6 @@ def match(self, node: Layer):
return isinstance(node, FixedPointQuantizer)

def transform(self, model, node: FixedPointQuantizer):
if node.fusible:
model.remove_node(node)
return True

if model.config.config['IOType'] != 'io_parallel':
raise NotImplementedError('Heterogenous quantization for activations is only supported with IOType=io_parallel')

@@ -96,7 +92,6 @@ def __init__(self):

def format(self, node):
params = self._default_function_params(node)
node.attributes['result_t'].precision = node.attributes['table_t'].precision
params['config'] = f'unary_lut_config{node.index}'
params['table'] = node.get_weights('table').name

12 changes: 12 additions & 0 deletions hls4ml/backends/oneapi/oneapi_backend.py
@@ -179,6 +179,12 @@ def compile(self, model):
try:
subprocess.run('which icpx', shell=True, cwd=builddir, check=True)
except subprocess.CalledProcessError:
try:
import pytest

pytest.skip('icpx not present')
except ImportError:
pass
raise RuntimeError('Could not find icpx. Please configure oneAPI appropriately')
subprocess.run('cmake ..', shell=True, cwd=builddir, check=True)
subprocess.run('make lib', shell=True, cwd=builddir, check=True)
@@ -204,6 +210,12 @@ def build(self, model, build_type='fpga_emu', run=False):
try:
subprocess.run('which icpx', shell=True, cwd=builddir, check=True)
except subprocess.CalledProcessError:
try:
import pytest

pytest.skip('icpx not present')
except ImportError:
pass
raise RuntimeError('Could not find icpx. Please configure oneAPI appropriately')
subprocess.run('cmake ..', shell=True, cwd=builddir, check=True)
subprocess.run(f'make {build_type}', shell=True, cwd=builddir, check=True)
46 changes: 41 additions & 5 deletions hls4ml/backends/vivado/passes/convolution_templates.py
@@ -61,7 +61,7 @@
template<unsigned K, unsigned S, unsigned W>
using scale_index = nnet::{scale_index_type}<K, S, W>;
template<class data_T, class res_T, class CONFIG_T>
using conv_kernel = nnet::{conv_fn}<data_T, res_T, CONFIG_T>;
using conv_kernel = {conv_fn}<data_T, res_T, CONFIG_T>;
}};
const ap_uint<config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""

@@ -90,8 +90,8 @@ def format(self, node):
else:
params['scale_index_type'] = 'scale_index_regular'

namespace = params['namespace']
if node.model.config.get_config_value('IOType') == 'io_parallel':
namespace = params['namespace']
params['fill_fn'] = f'{namespace}::fill_buffer_{node.index}'
else:
params['fill_fn'] = 'nnet::FillConv1DBuffer'
@@ -102,12 +102,12 @@
and node.model.config.get_config_value('IOType') == 'io_parallel'
)
if is_pointwise_parallel_latency:
params['conv_fn'] = f'pointwise_conv_{node.index}'
params['conv_fn'] = f'{namespace}::pointwise_conv_{node.index}'
else:
if node.get_attr('strategy').lower() == 'latency':
params['conv_fn'] = 'Conv1DLatency'
params['conv_fn'] = 'nnet::Conv1DLatency'
else:
params['conv_fn'] = 'Conv1DResource'
params['conv_fn'] = 'nnet::Conv1DResource'

params['min_width'] = node.get_attr('min_width', node.get_attr('in_width'))
params['instructions'] = node.get_attr('instructions', '0')
@@ -154,11 +154,21 @@ def format(self, node):
mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin'
elif node.get_attr('strategy').lower() == 'resource_unrolled':
mult_params['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}'
elif node.get_attr('strategy').lower() == 'distributed_arithmetic':
mult_params['dense_function'] = f'{namespace}::dense_da_wrapper_{node.index}'

mult_config = self.mult_template.format(**mult_params)

return mult_config + '\n' + conv_config

def match(self, node):
if node.get_attr('strategy') == 'distributed_arithmetic':
io_type = node.model.config.get_config_value("IOType")
if io_type == 'io_parallel':
# DA impl uses an alternate entry point for io_parallel conv
return False
return super().match(node)


class Conv1DFunctionTemplate(FunctionCallTemplate):
def __init__(self):
@@ -173,6 +183,14 @@ def format(self, node):

return self.template.format(**params)

def match(self, node):
if node.get_attr('strategy') == 'distributed_arithmetic':
io_type = node.model.config.get_config_value("IOType")
if io_type == 'io_parallel':
# DA impl uses an alternate entry point for io_parallel conv
return False
return super().match(node)


class DepthwiseConv1DFunctionTemplate(Conv1DFunctionTemplate):
def __init__(self):
@@ -299,11 +317,21 @@ def format(self, node):
mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin'
elif node.get_attr('strategy').lower() == 'resource_unrolled':
mult_params['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}'
elif node.get_attr('strategy').lower() == 'distributed_arithmetic':
mult_params['dense_function'] = f'{namespace}::dense_da_wrapper_{node.index}'

mult_config = self.mult_template.format(**mult_params)

return mult_config + '\n' + conv_config

def match(self, node):
if node.get_attr('strategy') == 'distributed_arithmetic':
io_type = node.model.config.get_config_value("IOType")
if io_type == 'io_parallel':
# DA impl uses an alternate entry point for io_parallel conv
return False
return super().match(node)


class Conv2DFunctionTemplate(FunctionCallTemplate):
def __init__(self):
@@ -318,6 +346,14 @@ def format(self, node):

return self.template.format(**params)

def match(self, node):
if node.get_attr('strategy') == 'distributed_arithmetic':
io_type = node.model.config.get_config_value("IOType")
if io_type == 'io_parallel':
# DA impl uses an alternate entry point for io_parallel conv
return False
return super().match(node)


class DepthwiseConv2DFunctionTemplate(Conv2DFunctionTemplate):
def __init__(self):
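For context, the `match()` overrides above hand io_parallel convolutions with the 'distributed_arithmetic' strategy off to a DA-specific entry point instead of the common conv templates. A minimal usage sketch follows, assuming the new strategy is selected through the same per-model or per-layer `Strategy` key used for the existing 'Latency'/'Resource' strategies; `model` and `dense_1` are placeholder names, and the backend/io_type choice is only illustrative.

import hls4ml

# Illustrative sketch only: enable the distributed-arithmetic strategy the same
# way existing strategies are chosen in the hls4ml configuration dictionary.
config = hls4ml.utils.config_from_keras_model(model, granularity='name')
config['Model']['Strategy'] = 'distributed_arithmetic'  # model-wide
config['LayerName']['dense_1']['Strategy'] = 'distributed_arithmetic'  # or per layer

hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, backend='Vitis', io_type='io_parallel'
)
hls_model.compile()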
84 changes: 82 additions & 2 deletions hls4ml/backends/vivado/passes/core_templates.py
@@ -1,3 +1,5 @@
from math import ceil, log2

from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax
@@ -55,9 +57,17 @@ def format(self, node):
# The 3rd case is never used
elif node.get_attr('strategy').lower() == 'resource_unrolled':
params['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}'
elif node.get_attr('strategy').lower() == 'distributed_arithmetic':
# Only triggered in io_streaming mode
params['dense_function'] = f'{namespace}::dense_da_wrapper_{node.index}'

return self.template.format(**params)

def match(self, node):
if node.get_attr('strategy') == 'distributed_arithmetic':
return False # DA does not use common dense template
return super().match(node)


class DenseFunctionTemplate(FunctionCallTemplate):
def __init__(self):
@@ -71,6 +81,11 @@ def format(self, node):

return self.template.format(**params)

def match(self, node):
if node.get_attr('strategy') == 'distributed_arithmetic':
return False # DA does not use common dense template
return super().match(node)


# BatchNormalization templates

@@ -152,13 +167,22 @@ def format(self, node):

softmax_config_template = """struct {type}_config{index} : nnet::activ_config {{
static const unsigned n_in = {n_in};
static const unsigned table_size = {table_size};
static const unsigned n_slice = {n_slice};
static const unsigned n_outer = {n_outer};
static const unsigned n_inner = {n_inner};
static const unsigned parallelization_factor = {parallelization_factor};
static const unsigned exp_table_size = {exp_table_size};
static const unsigned inv_table_size = {inv_table_size};
static const unsigned io_type = nnet::{iotype};
static const unsigned reuse_factor = {reuse};
static const unsigned axis = {axis};
static const nnet::softmax_implementation implementation = nnet::softmax_implementation::{implementation};
static constexpr float exp_scale = {exp_scale};
typedef {exp_table_t.name} exp_table_t;
typedef {inv_table_t.name} inv_table_t;
typedef {accum_t.name} accum_t;
typedef {inv_inp_t.name} inv_inp_t;
typedef {inp_norm_t_str} inp_norm_t;
}};\n"""

activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});'
@@ -210,10 +234,66 @@ def __init__(self):
super(ActivationConfigTemplate, self).__init__(Softmax) # Skip ActivationConfigTemplate's __init__
self.template = softmax_config_template

def format(self, node):
params = self._default_config_params(node)
params['type'] = node.get_attr('activation')
params.setdefault('exp_table_size', params['table_size'])
params.setdefault('inv_table_size', params['table_size'])
params.setdefault('n_inner', 1)
params.setdefault('n_outer', 1)
params.setdefault('exp_scale', 1.0)
params.setdefault('parallelization_factor', -1)

n_slice = params['n_in'] // params['n_inner'] // params['n_outer'] # type: ignore
params['n_slice'] = n_slice

if params['accum_t'].name == 'model_default_t': # type: ignore
scale = ceil(log2(n_slice))
exp_table_t = node.attributes['exp_table_t'].precision
signed, width, integers = exp_table_t.signed, exp_table_t.width, exp_table_t.integer
params['accum_t_str'] = f'ap_{"" if signed else "u"}fixed<{width + scale}, {integers + scale}>'
else:
params['accum_t_str'] = params['accum_t'].name # type: ignore
if params['inv_inp_t'].name == 'model_default_t': # type: ignore
params['inv_inp_t'] = params['exp_table_t']

if params['implementation'] == 'stable':
if 'inp_norm_t' not in params:
# Only used in stable (max-normalized) implementation
input_t = node.get_input_variable().type.precision
width, iwidth, signed = input_t.width, input_t.integer, input_t.signed # noqa: F841
width, iwidth = width - signed, iwidth - signed
if signed:
# Fix table size if too large
exp_table_size = params['inv_table_size']
params['exp_table_size'] = str(min(int(exp_table_size), 2**width))
params['inp_norm_t_str'] = f'ap_ufixed<{width}, {iwidth}>'
else:
params['inp_norm_t_str'] = params['inp_norm_t'].name # type: ignore
else:
params['inp_norm_t_str'] = 'ap_fixed<1,0>'

return self.template.format(**params)


class SoftmaxFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(Softmax, include_header=activ_include_list)
self.template = activ_function_template

def format(self, node):
params = self._default_function_params(node)
use_multidim = node.get_attr('n_inner', 1) > 1 or node.get_attr('n_outer', 1) > 1
use_multidim = use_multidim and node.model.config.get_config_value('IOType') == 'io_parallel'
params['activation'] = 'softmax' if not use_multidim else 'softmax_multidim'
params['config'] = f'softmax_config{node.index}'

return self.template.format(**params)


class ActivationFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__((Activation, HardActivation, Softmax), include_header=activ_include_list)
super().__init__((Activation, HardActivation), include_header=activ_include_list)
self.template = activ_function_template

def format(self, node):
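As a reading aid for the accumulator sizing in `SoftmaxConfigTemplate.format` above: when `accum_t` is left at the model default, the generated type is the exp-table type widened by ceil(log2(n_slice)) bits in both total and integer width, so that the sum of n_slice exponential terms cannot overflow. A small sketch of that rule, using a hypothetical helper name and the default 18/8 widths from the attribute definitions above:

from math import ceil, log2

# Hypothetical helper mirroring the sizing rule above: grow the exp-table type
# by ceil(log2(n_slice)) bits so that summing n_slice exponentials cannot overflow.
def softmax_accum_type(exp_width: int, exp_integer: int, n_slice: int, signed: bool = True) -> str:
    scale = ceil(log2(n_slice))
    prefix = '' if signed else 'u'
    return f'ap_{prefix}fixed<{exp_width + scale}, {exp_integer + scale}>'

# Default exp_table_t is ap_fixed<18, 8>; a softmax over 10 elements gives:
print(softmax_accum_type(18, 8, n_slice=10))  # -> ap_fixed<22, 12>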