
Commit 93507e8

fixes in bw inference

1 parent: 7f19a1a

File tree

1 file changed (+22, -6 lines)


hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 22 additions & 6 deletions
@@ -120,8 +120,9 @@ def _(layer: Activation):
     if fn_name == 'linear':
         return (requested_kif(layer),)
     if fn_name == 'relu':
-        k, i, f = requested_kif(layer)
-        k = np.ones_like(k)
+        _, _, f = requested_kif(layer)
+        k = np.ones(f.shape, dtype=np.int8)
+        i = np.full(f.shape, 126, dtype=np.int8)
         return ((k, i, f),)
     inp_shape = get_input_shapes(layer)[0]
     return (_maximum_kif_at_shape(inp_shape),)
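
The relu branch above no longer forwards the downstream integer-bit request to the producer: the ReLU input may be negative and wider than its clipped output, so the backward pass now always requests a sign bit and an effectively unbounded integer width (126 bits) and only propagates the fractional-bit request. A minimal sketch of just that arithmetic, with a made-up per-element fraction request standing in for the value returned by requested_kif:

    import numpy as np

    # Hypothetical fractional-bit request coming from the layers downstream of the ReLU.
    f = np.array([4, 6, 8], dtype=np.int8)

    # What the new branch computes: always ask the producer for a sign bit and an
    # effectively unconstrained integer width, and pass the fraction request through.
    k = np.ones(f.shape, dtype=np.int8)
    i = np.full(f.shape, 126, dtype=np.int8)

    print(k, i, f)  # [1 1 1] [126 126 126] [4 6 8]
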
@@ -478,9 +479,6 @@ def default_register_precision(layer: Layer):
     _ok, _oi, _of = np.minimum(_pk, _rk), np.minimum(_pi, _ri), np.minimum(_pf, _rf)
     ok, oi, of = kif_arrs_to_ints((_ok, _oi, _of))
 
-    if np.max(_pf) > np.max(_rf) and np.max(_pi) <= np.max(_ri):
-        oi += 1  # Edge cases overflow prevention
-
     result_t = to_hls4ml_fixed(ok, oi, of, f'{layer.name}_t')
     layer.attributes.attributes['result_t'] = result_t
     layer.get_output_variable().type = result_t
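
The deleted lines padded the integer width by one bit in a specific corner case; with that workaround gone, the output type is just the element-wise minimum of the produced and requested bit budgets, collapsed to scalar widths and handed to to_hls4ml_fixed. A small sketch of that remaining computation, with made-up bit budgets and numpy's max standing in for kif_arrs_to_ints (assumed here to reduce each per-element array to a single width):

    import numpy as np

    # Hypothetical produced (_p*) and requested (_r*) keep/integer/fraction bits per element.
    _pk, _pi, _pf = np.array([1, 1]), np.array([3, 4]), np.array([8, 8])
    _rk, _ri, _rf = np.array([1, 1]), np.array([4, 4]), np.array([6, 6])

    # Element-wise minimum of what the producer delivers and what consumers request.
    _ok, _oi, _of = np.minimum(_pk, _rk), np.minimum(_pi, _ri), np.minimum(_pf, _rf)

    # Stand-in for kif_arrs_to_ints: collapse each array to one scalar width.
    ok, oi, of = (int(a.max()) for a in (_ok, _oi, _of))
    print(ok, oi, of)  # 1 4 6 -> an 11-bit word with 5 bits above the binary point
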
@@ -495,7 +493,7 @@ def default_register_precision(layer: Layer):
 
     # Set precision for fixed array (weight_t, bias_t, table_t, etc.)
     for w_name_t, v in layer.attributes.attributes.items():
-        if not isinstance(v, NamedType) and w_name_t.endswith('_t'):
+        if not isinstance(v, NamedType) and not w_name_t.endswith('_t'):
             continue  # Not a precision, skip
 
         w_name = w_name_t[:-2]
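
The corrected predicate keeps an attribute for precision registration when it is already a NamedType or its name ends in _t, and skips everything else; the old condition appears to have skipped exactly those *_t entries whose value was not yet a NamedType. A toy illustration of the new filter, with a stand-in NamedType and made-up attribute names:

    class NamedType:  # stand-in for hls4ml's NamedType
        pass

    attributes = {
        'weight_t': NamedType(),   # named precision -> kept
        'accum_t': 'fixed<16,6>',  # *_t name with a plain value -> now kept too
        'table_size': 1024,        # not a precision -> skipped
    }

    for w_name_t, v in attributes.items():
        if not isinstance(v, NamedType) and not w_name_t.endswith('_t'):
            continue  # neither a NamedType nor a *_t name: not a precision
        print('register precision for', w_name_t)
    # register precision for weight_t
    # register precision for accum_t
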
@@ -531,6 +529,24 @@ def register_precision(node: Layer):
     default_register_precision(node)
 
 
+@register_precision.register
+def _(node: Activation):
+    default_register_precision(node)
+    act_fn = node.attributes['activation'].lower()
+    _k, _i, _f = get_input_kifs(node)[0]
+    k, i, f = kif_arrs_to_ints((_k, _i, _f))
+    table_size = int(2 ** (k + i + f))
+
+    # Temporary workaround for sigmoid and tanh activations, which scale the input by constant factors
+    # TODO: Rewrite tanh and sigmoid fn templates
+    if act_fn == 'tanh':
+        table_size = int(8 / 2.0**-f)  # LUT Range hardcoded to -4 ~ 4, match #fractional bits
+    elif act_fn == 'sigmoid':
+        table_size = int(16 / 2.0**-f)  # LUT Range hardcoded to -8 ~ 8, match #fractional bits
+
+    node.attributes['table_size'] = table_size
+
+
 @register_precision.register
 def _(node: Softmax):
     if not node.attributes.get('_bit_exact', False):
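
The new Activation handler sizes the activation LUT from the input word width, with the tanh and sigmoid special cases matching the input ranges hardcoded in their function templates. A worked example of the three sizes, assuming an input type with one sign bit, three integer bits and ten fractional bits:

    # Worked example of the table sizing above for an assumed 14-bit input
    # (k=1 sign bit, i=3 integer bits, f=10 fractional bits).
    k, i, f = 1, 3, 10

    default_size = int(2 ** (k + i + f))  # one entry per representable input: 16384
    tanh_size = int(8 / 2.0**-f)          # -4 ~ 4 range at a 2**-f step: 8192
    sigmoid_size = int(16 / 2.0**-f)      # -8 ~ 8 range at a 2**-f step: 16384

    print(default_size, tanh_size, sigmoid_size)  # 16384 8192 16384
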
