
Commit d63033b

Fix tanh activation in pytorch parser (#1055)

* fix tanh activation in pytorch parser
* simplify fix by making the activation attribute lower case
1 parent c8c95a7 commit d63033b

File tree

3 files changed, +20 -9 lines changed


hls4ml/converters/pytorch/core.py

Lines changed: 3 additions & 5 deletions
@@ -43,14 +43,12 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, node
     layer = {}
 
     layer['class_name'] = operation
-    layer['activation'] = layer['class_name']
+    layer['activation'] = layer['class_name'].lower()
     layer['name'] = layer_name
     layer['inputs'] = input_names
 
-    # if layer['class_name'] != 'Activation':
-    #     layer['activation'] = layer['class_name']
     if node.op == 'call_module':
-        if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid':
+        if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
             layer['class_name'] = 'Activation'
         if layer['class_name'] == 'LeakyReLU':
             layer['activ_param'] = class_object.negative_slope
@@ -68,7 +66,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, node
         if hasattr(node, 'dim'):
             layer['axis'] = class_object.dim
     else:
-        if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid':
+        if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
             layer['class_name'] = 'Activation'
         if layer['class_name'] == 'LeakyReLU':
             layer['activ_param'] = node.kwargs['negative_slope']
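
For readers less familiar with the converter, the snippet below sketches what the updated logic does: the activation attribute is now stored in lower case (matching hls4ml's internal activation names), and Tanh joins ReLU and Sigmoid in being re-labelled as a generic Activation layer. This is a simplified stand-alone sketch, not the actual parse_activation_layer function; the helper name and the omitted arguments are illustrative only.

# Simplified, stand-alone sketch of the updated parsing behaviour (illustrative only).
def parse_activation_sketch(operation, layer_name, input_names):
    layer = {}
    layer['class_name'] = operation                     # e.g. 'Tanh'
    layer['activation'] = layer['class_name'].lower()   # 'tanh' matches the internal naming
    layer['name'] = layer_name
    layer['inputs'] = input_names
    # ReLU, Sigmoid and Tanh are all served by the generic Activation layer type.
    if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
        layer['class_name'] = 'Activation'
    return layer


print(parse_activation_sketch('Tanh', 'tanh_1', ['x']))
# {'class_name': 'Activation', 'activation': 'tanh', 'name': 'tanh_1', 'inputs': ['x']}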

hls4ml/converters/pytorch_to_hls.py

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ def decorator(function):
 # map names of operations between toch.nn and torch.nn.functionals
 layer_name_map = {
     'relu': 'ReLU',
+    'tanh': 'Tanh',
     'leaky_relu': 'LeakyReLU',
     'elu': 'ELU',
     'prelu': 'PReLU',
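
To illustrate why this one-line addition matters: graph nodes created from torch.nn.functional calls carry the functional name (e.g. 'tanh'), and layer_name_map translates that into the module-style class name so the functional and module forms share one parsing path. A trimmed-down illustration follows; the dictionary below is only an excerpt, not the full map.

# Trimmed excerpt of the functional-name -> module-name mapping, for illustration only.
layer_name_map = {
    'relu': 'ReLU',
    'tanh': 'Tanh',        # newly added entry
    'leaky_relu': 'LeakyReLU',
}

# A traced call to torch.nn.functional.tanh(x) appears with the target name 'tanh';
# the map renames it so it is parsed exactly like an nn.Tanh() module.
operation = layer_name_map.get('tanh', 'tanh')
print(operation)  # Tanh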

test/pytest/test_pytorch_api.py

Lines changed: 16 additions & 4 deletions
@@ -64,6 +64,7 @@ def test_linear(backend, io_type):
     "activation_function",
     [
         nn.ReLU(),
+        nn.Tanh(),
         nn.LeakyReLU(negative_slope=1.0),
         nn.ELU(alpha=1.0),
         nn.PReLU(init=0.25),
@@ -102,7 +103,7 @@ def test_activations(activation_function, backend, io_type):
 
     assert nNodes - 1 == len(hls_model.get_layers())
 
-    if activation_function.__class__.__name__ == 'ReLU' or activation_function.__class__.__name__ == 'Sigmoid':
+    if activation_function.__class__.__name__ in ['ReLU', 'Sigmoid', 'Tanh']:
         assert list(hls_model.get_layers())[2].attributes['class_name'] == 'Activation'
     elif activation_function.__class__.__name__ == 'Threshold':
         assert list(hls_model.get_layers())[2].attributes['class_name'] == 'ThresholdedReLU'
@@ -118,6 +119,14 @@ def forward(self, x):
         return nn.functional.relu(x)
 
 
+class TanHModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return nn.functional.tanh(x)
+
+
 class LeakyReLuModel(nn.Module):
     def __init__(self):
         super().__init__()
@@ -154,6 +163,7 @@ def forward(self, x):
     "activation_function",
     [
         ReLuModel(),
+        TanHModel(),
         LeakyReLuModel(),
         EluModel(),
         SigmoidModel(),
@@ -172,7 +182,7 @@ def test_activation_functionals(activation_function, backend, io_type):
 
     config = config_from_pytorch_model(model, (1,))
     fn_name = activation_function.__class__.__name__
-    output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_activations_functional_relu_{backend}_{io_type}_{fn_name}')
+    output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_activations_functional_{fn_name}_{backend}_{io_type}')
     hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type)
     hls_model.compile()
 
@@ -268,7 +278,7 @@ def test_conv1d(padds, backend, io_type):
     act_index = 2
     assert list(hls_model.get_layers())[conv_index].attributes['name'] == convNode.name
     assert list(hls_model.get_layers())[conv_index].attributes['class_name'] == 'Conv1D'
-    assert list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__
+    assert list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__.lower()
     if io_type == "io_stream" and (backend == "Vivado" or backend == "Vitis") and padds == 1:
         assert list(hls_model.get_layers())[conv_index].attributes["in_width"] == size_in + 2
     else:
@@ -412,7 +422,9 @@ def test_conv2d(padds, backend, io_type):
     act_index = 2
     assert list(hls_model.get_layers())[conv_index].attributes['name'] == convNode.name
     assert list(hls_model.get_layers())[conv_index].attributes['class_name'] == 'Conv2D'
-    assert list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__
+    assert (
+        list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__.lower()
+    )
     assert list(hls_model.get_layers())[conv_index].attributes["in_width"] == size_in_width
     assert list(hls_model.get_layers())[conv_index].attributes["in_height"] == size_in_height
     assert list(hls_model.get_layers())[conv_index].attributes['filt_width'] == class_object_conv.kernel_size[1]
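
For reference, a minimal end-to-end usage example in the spirit of the new test cases; the model, output directory and backend below are placeholders, and the import pattern follows the one used in the hls4ml pytest suite (requires PyTorch and hls4ml installed).

import torch.nn as nn

from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model

# Tiny placeholder model exercising the tanh activation handled by this commit.
model = nn.Sequential(nn.Linear(8, 8), nn.Tanh())
model.eval()

config = config_from_pytorch_model(model, (8,))
hls_model = convert_from_pytorch_model(
    model,
    hls_config=config,
    output_dir='hls4mlprj_tanh_example',  # placeholder output directory
    backend='Vivado',                     # any supported backend
    io_type='io_parallel',
)
hls_model.compile()
# After compilation, hls_model.predict(...) should agree with the PyTorch model
# to within the chosen fixed-point precision.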
