@@ -120,8 +120,9 @@ def _(layer: Activation):
     if fn_name == 'linear':
         return (requested_kif(layer),)
     if fn_name == 'relu':
-        k, i, f = requested_kif(layer)
-        k = np.ones_like(k)
+        _, _, f = requested_kif(layer)
+        k = np.ones(f.shape, dtype=np.int8)
+        i = np.full(f.shape, 126, dtype=np.int8)
         return ((k, i, f),)
     inp_shape = get_input_shapes(layer)[0]
     return (_maximum_kif_at_shape(inp_shape),)
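For reference, here is what the new relu branch yields on its own, with an invented 4-element request and the k/i/f convention (sign, integer, fractional bits) the surrounding pass appears to use: only the fractional bits of the downstream request survive, while the sign bit is kept and the integer bits are left effectively unbounded (126 presumably being a safe int8 sentinel):

import numpy as np

# Invented downstream request: 6 fractional bits per element
f = np.full((4,), 6, dtype=np.int8)

k = np.ones(f.shape, dtype=np.int8)       # keep the sign bit: relu input may be negative
i = np.full(f.shape, 126, dtype=np.int8)  # effectively unbounded integer bits

print(k, i, f)  # [1 1 1 1] [126 126 126 126] [6 6 6 6]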
@@ -478,9 +479,6 @@ def default_register_precision(layer: Layer):
     _ok, _oi, _of = np.minimum(_pk, _rk), np.minimum(_pi, _ri), np.minimum(_pf, _rf)
     ok, oi, of = kif_arrs_to_ints((_ok, _oi, _of))

-    if np.max(_pf) > np.max(_rf) and np.max(_pi) <= np.max(_ri):
-        oi += 1  # Edge cases overflow prevention
-
     result_t = to_hls4ml_fixed(ok, oi, of, f'{layer.name}_t')
     layer.attributes.attributes['result_t'] = result_t
     layer.get_output_variable().type = result_t
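With the edge-case guard removed, the output precision is now purely the element-wise minimum of what the layer produces (`_p*`) and what downstream requests (`_r*`). A small sketch with invented values:

import numpy as np

# Invented per-element precisions: produced (_p*) vs. requested (_r*)
_pk, _pi, _pf = np.array([1, 1]), np.array([5, 3]), np.array([8, 8])
_rk, _ri, _rf = np.array([1, 0]), np.array([7, 2]), np.array([4, 6])

_ok = np.minimum(_pk, _rk)  # [1 0] sign bit kept only where both sides need it
_oi = np.minimum(_pi, _ri)  # [5 2] integer bits
_of = np.minimum(_pf, _rf)  # [4 6] fractional bits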
@@ -495,7 +493,7 @@ def default_register_precision(layer: Layer):

     # Set precision for fixed array (weight_t, bias_t, table_t, etc.)
     for w_name_t, v in layer.attributes.attributes.items():
-        if not isinstance(v, NamedType) and w_name_t.endswith('_t'):
+        if not isinstance(v, NamedType) and not w_name_t.endswith('_t'):
             continue  # Not a precision, skip

         w_name = w_name_t[:-2]
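The fixed guard skips an entry only when it is neither a `NamedType` nor named like a precision; the old form instead skipped plain `*_t` values and let unrelated attributes fall through to the precision-setting code below. A standalone restatement of the new predicate (the `NamedType` import path is assumed):

from hls4ml.model.types import NamedType  # import path assumed

def is_precision_entry(name: str, value) -> bool:
    # Mirrors the corrected guard: process NamedTypes and '*_t' keys, skip the rest
    return isinstance(value, NamedType) or name.endswith('_t')

# is_precision_entry('weight_t', weight_type)  -> True  (processed)
# is_precision_entry('n_in', 16)               -> False (skipped)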
@@ -531,6 +529,24 @@ def register_precision(node: Layer):
     default_register_precision(node)


+@register_precision.register
+def _(node: Activation):
+    default_register_precision(node)
+    act_fn = node.attributes['activation'].lower()
+    _k, _i, _f = get_input_kifs(node)[0]
+    k, i, f = kif_arrs_to_ints((_k, _i, _f))
+    table_size = int(2 ** (k + i + f))
+
+    # Temporary workaround for sigmoid and tanh activations, which scale the input by constant factors
+    # TODO: Rewrite tanh and sigmoid fn templates
+    if act_fn == 'tanh':
+        table_size = int(8 / 2.0**-f)  # LUT range hardcoded to -4 ~ 4; match the number of fractional bits
+    elif act_fn == 'sigmoid':
+        table_size = int(16 / 2.0**-f)  # LUT range hardcoded to -8 ~ 8; match the number of fractional bits
+
+    node.attributes['table_size'] = table_size
+
+
 @register_precision.register
 def _(node: Softmax):
     if not node.attributes.get('_bit_exact', False):
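Worked numbers for the new table sizing, using an invented input precision of (k, i, f) = (1, 3, 10):

k, i, f = 1, 3, 10  # invented: sign bit, 3 integer bits, 10 fractional bits

table_size = int(2 ** (k + i + f))  # 16384: one LUT entry per representable input
tanh_size = int(8 / 2.0**-f)        # 8192:  [-4, 4) sampled in 2**-10 steps
sigmoid_size = int(16 / 2.0**-f)    # 16384: [-8, 8) sampled in 2**-10 steps

The two special cases size the table to the hardcoded LUT range at the input's fractional resolution, rather than allotting one entry per representable input value.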