Skip to content

Simple PyTorch extension API #1247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions hls4ml/converters/pytorch_to_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def decorator(function):
return decorator


# map names of operations between toch.nn and torch.nn.functionals
# map names of operations between torch.nn and torch.nn.functionals
layer_name_map = {
'relu': 'ReLU',
'tanh': 'Tanh',
Expand Down Expand Up @@ -119,7 +119,8 @@ def parse_pytorch_model(config, verbose=True):
ModelGraph: hls4ml model object.
"""
import torch
from torch.fx import symbolic_trace

from hls4ml.utils.torch import CustomFXTracer

# This is a list of dictionaries to hold all the layer info we need to generate HLS
layer_list = []
Expand All @@ -136,11 +137,12 @@ def parse_pytorch_model(config, verbose=True):

model = reader.torch_model

# dict of layer objects in non-traced form for access lateron
# dict of layer objects in non-traced form for access later on
children = {c[0]: c[1] for c in model.named_children()}
# use symbolic_trace to get a full graph of the model

traced_model = symbolic_trace(model)
tracer = CustomFXTracer()
traced_model = tracer.trace(model)
# Define layers to skip for conversion to HLS
skip_layers = ['Dropout', 'Sequential']

Expand All @@ -167,21 +169,19 @@ def parse_pytorch_model(config, verbose=True):
# check for constant nodes
merge_layers = ['add', 'mul', 'sub', 'fmin', 'fmax']
i = 0 # count number of consts and use it in the name
for node in traced_model.graph.nodes:
for node in traced_model.nodes:
if node.name.split('_')[0] in merge_layers:
for arg in node.args:
if np.isscalar(arg):
# add an input node with the constant value
new_node = traced_model.graph.placeholder(
name='const_' + str(i), type_expr=torch.Tensor, default_value=arg
)
new_node = traced_model.placeholder(name='const_' + str(i), type_expr=torch.Tensor, default_value=arg)
node.prepend(new_node)
node.update_arg(1, new_node)
i += 1

traced_model.graph.lint()
traced_model.lint()

for node in traced_model.graph.nodes:
for node in traced_model.nodes:
if node.op == 'call_module':
# modules that are part of a torch.nn.Sequential with name 'name' have target names 'name.x',
# where x is an integer numbering the elements of the Sequential
Expand Down Expand Up @@ -238,7 +238,7 @@ def parse_pytorch_model(config, verbose=True):

# if a 'getitem' is the input to a node, step back in the graph to find the real source of the input
elif "getitem" in node.args[0].name:
for tmp_node in traced_model.graph.nodes:
for tmp_node in traced_model.nodes:
if tmp_node.name == node.args[0].name:
if "getitem" in tmp_node.args[0].name:
raise Exception('Nested getitem calles not resolved at the moment.')
Expand Down
26 changes: 26 additions & 0 deletions hls4ml/utils/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch


class HLS4MLModule(torch.nn.Module):
    """
    Marker base class for user-defined PyTorch modules that hls4ml must treat
    as opaque, i.e. that should not be traced through by torch.FX.
    """

    pass


class CustomFXTracer(torch.fx.Tracer):
    """
    Custom torch.FX tracer for hls4ml that keeps Brevitas modules and hls4ml
    custom modules (HLS4MLModule subclasses) as leaf modules, so they are not
    traced through by torch.FX.
    """

    def is_leaf_module(self, m, module_qualified_name: str) -> bool:
        """Return True if ``m`` should appear as a single ``call_module`` node.

        Built-in torch.nn / torch.ao.nn modules, brevitas.nn modules and
        HLS4MLModule subclasses are leaves; ``torch.nn.Sequential`` containers
        are always traced through so their children are converted individually.
        """
        # NOTE: torch is imported at module level; no local re-import needed.
        return (
            isinstance(m, HLS4MLModule)
            or m.__module__.startswith('torch.nn')
            or m.__module__.startswith('torch.ao.nn')
            or m.__module__.startswith('brevitas.nn')
        ) and not isinstance(m, torch.nn.Sequential)
192 changes: 192 additions & 0 deletions test/pytest/test_extensions_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from pathlib import Path

import numpy as np
import pytest
import torch

import hls4ml
import hls4ml.utils.torch

test_root_path = Path(__file__).parent


# PyTorch implementation of a custom layer
class TReverse(hls4ml.utils.torch.HLS4MLModule):
    '''Toy custom layer: reverses its input along the last dimension.'''

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # Tensor-method form of torch.flip(inputs, dims=[-1])
        return inputs.flip(dims=[-1])


# hls4ml layer implementation
class HReverse(hls4ml.model.layers.Layer):
    '''hls4ml IR layer mirroring the hypothetical TReverse PyTorch module.'''

    def initialize(self):
        # Reversal is shape-preserving: the output variable copies the
        # input's shape and dimension names unchanged.
        input_var = self.get_input_variable()
        self.add_output_variable(input_var.shape, input_var.dim_names)


# hls4ml optimizer pass to remove consecutive (duplicate) reverse layers
class RemoveDuplicateReverse(hls4ml.model.optimizer.OptimizerPass):
    '''OptimizerPass that drops pairs of back-to-back HReverse layers (two reversals cancel).'''

    def match(self, node):
        # Fire on the second of two consecutive HReverse nodes.
        if not isinstance(node, HReverse):
            return False
        return isinstance(node.get_input_node(), HReverse)

    def transform(self, model, node):
        # Capture the upstream node before removals mutate the graph.
        upstream = node.get_input_node()
        model.remove_node(upstream)
        model.remove_node(node)
        return True


# Parser for converter
def parse_reverse_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
    '''Converter handler mapping a TReverse PyTorch node onto an HReverse hls4ml layer.

    Returns the layer dictionary and the (unchanged) output shape — reversal
    is shape-preserving.
    '''
    assert operation == 'TReverse'

    layer = {
        'class_name': 'HReverse',
        'name': layer_name,
        'n_in': input_shapes[0][1],
    }

    if input_names is not None:
        layer['inputs'] = input_names

    return layer, list(input_shapes[0])


# HLS Templates - No specific pragmas used; generic enough for both Intel and Vivado

# C++ config-struct template; {index} and {n_in} are filled per-layer by
# HReverseConfigTemplate.format via _default_config_params.
rev_config_template = """struct config{index} : nnet::reverse_config {{
static const unsigned n_in = {n_in};
}};\n"""

# Function-call template and header include list consumed by HReverseFunctionTemplate.
rev_function_template = 'nnet::reverse<{input_t}, {config}>({input}, {output});'
rev_include_list = ['nnet_utils/nnet_reverse.h']


class HReverseConfigTemplate(hls4ml.backends.template.LayerConfigTemplate):
    '''Emits the per-layer config struct for HReverse from rev_config_template.'''

    def __init__(self):
        super().__init__(HReverse)
        self.template = rev_config_template

    def format(self, node):
        # Fill the template from the node's default config parameters.
        cfg = self._default_config_params(node)
        return self.template.format(**cfg)


class HReverseFunctionTemplate(hls4ml.backends.template.FunctionCallTemplate):
    '''Emits the nnet::reverse function call for HReverse from rev_function_template.'''

    def __init__(self):
        super().__init__(HReverse, include_header=rev_include_list)
        self.template = rev_function_template

    def format(self, node):
        # Fill the template from the node's default function-call parameters.
        call_params = self._default_function_params(node)
        return self.template.format(**call_params)


# Minimal HLS implementation of the reverse kernel. At test time this string is
# written to nnet_reverse.h and registered with the backend via register_source(),
# so the generated project can compile the HReverse layer.
rev_hls = """#ifndef NNET_REVERSE_H_
#define NNET_REVERSE_H_

#include "nnet_common.h"

namespace nnet {

struct reverse_config {
static const unsigned n_in = 10;
};

template<class data_T, typename CONFIG_T>
void reverse(
data_T input[CONFIG_T::n_in],
data_T reversed[CONFIG_T::n_in]
) {
for (int i = 0; i < CONFIG_T::n_in; i++) {
reversed[CONFIG_T::n_in - 1 - i] = input[i];
}
}

}

#endif
"""


@pytest.fixture(scope='session', autouse=True)
def register_custom_layer():
    '''Session-wide, automatic registration of the custom layer machinery.'''
    # Teach the PyTorch converter how to parse TReverse nodes
    hls4ml.converters.register_pytorch_layer_handler('TReverse', parse_reverse_layer)

    # Make the HReverse class known to hls4ml's internal representation
    hls4ml.model.layers.register_layer('HReverse', HReverse)


@pytest.mark.parametrize('backend_id', ['Vivado', 'Vitis', 'Quartus'])
def test_extensions_pytorch(tmp_path, backend_id):
    '''End-to-end test of the PyTorch extension API.

    Registers an optimizer pass, config/function templates and an HLS source
    for the custom HReverse layer, converts a small PyTorch model containing
    TReverse layers, and checks that (a) the duplicate-reverse pass fired and
    (b) the hls4ml output matches PyTorch.
    '''
    # Register the optimization passes (if any)
    backend = hls4ml.backends.get_backend(backend_id)
    ip_flow = hls4ml.model.flow.get_flow(backend.get_default_flow())
    # Add the pass into the main optimization flow
    optimize_flow = [flow for flow in ip_flow.requires if ':optimize' in flow][0]
    optimizer_name = f'{backend_id.lower()}:remove_duplicate_reverse'  # fixed typo: was 'optmizer_name'
    backend.register_pass(optimizer_name, RemoveDuplicateReverse, flow=optimize_flow)

    # Register template passes for the given backend
    backend.register_template(HReverseConfigTemplate)
    backend.register_template(HReverseFunctionTemplate)

    # Register HLS implementation
    p = tmp_path / 'nnet_reverse.h'
    p.write_text(rev_hls)
    backend.register_source(p)

    # Test if it works: reverse2/reverse3 are back-to-back and should be
    # removed by the RemoveDuplicateReverse optimizer pass.
    class PyTorchModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.reverse1 = TReverse()
            self.relu = torch.nn.ReLU()
            self.reverse2 = TReverse()
            self.reverse3 = TReverse()

        def forward(self, x):
            x = self.reverse1(x)
            x = self.relu(x)
            x = self.reverse2(x)
            x = self.reverse3(x)
            return x

    pmodel = PyTorchModel()

    # Small integer inputs keep the ap_int<6> fixed-point comparison exact.
    x = torch.randint(-5, 5, (8,), dtype=torch.int32)
    pres = pmodel(x).detach().numpy()

    config = hls4ml.utils.config_from_pytorch_model(
        pmodel, (8,), default_precision='ap_int<6>', granularity='name', backend=backend_id
    )
    hmodel = hls4ml.converters.convert_from_pytorch_model(
        pmodel,
        output_dir=str(test_root_path / f'hls4mlprj_extensions_torch_{backend_id}'),
        backend=backend_id,
        io_type='io_parallel',
        hls_config=config,
    )

    hmodel.compile()
    hres = hmodel.predict(x.numpy().astype('float32'))

    # Check if the optimizer pass was applied
    assert optimizer_name in hmodel._applied_flows[0][optimize_flow]

    # Remove the pass from the "optimize" flow so later parametrized runs
    # start from a clean flow definition.
    hls4ml.model.flow.update_flow(optimize_flow, remove_optimizers=[optimizer_name])

    np.testing.assert_array_equal(pres, hres)
Loading