Support for multi graph build #1174


Merged
36 commits merged, Jun 26, 2025

Commits
b870395
split ModelGraph at specified layer name
dimdano Oct 11, 2024
13fcf9d
Add TCL script for automatic connection of subgraph IPs in Vivado
dimdano Oct 24, 2024
04c171e
initial support for stitched ip simulation
dimdano Dec 3, 2024
31b82f0
fix for multi input/output layers in graph splitting
dimdano Dec 19, 2024
11481f3
documentation for MultiModelGraph flow
dimdano Dec 20, 2024
9b2ddb1
faster rtl simulation
dimdano Jan 8, 2025
862521f
unwrap list if it has single element
dimdano Jan 10, 2025
c345aa4
Make MultiModelGraph adaptable to user-defined names
dimdano Jan 15, 2025
3cc7d04
stitch script time verbose
dimdano Jan 15, 2025
8f34ae1
fix with existing stitch project folder
dimdano Jan 15, 2025
acf4a42
initial support for multigraph compilation in bridge file
dimdano Jan 16, 2025
090dedb
stitched report fix for VivadoSynth aggregate
dimdano Jan 17, 2025
3656ec7
use log_to_stdout flag for parallel builds
dimdano Jan 21, 2025
94eb6dc
small change
dimdano Jan 24, 2025
22aabf1
remove bridged multigraph compilation for now
dimdano Jan 24, 2025
d45b837
[pre-commit.ci] auto fixes from pre-commit hooks
pre-commit-ci[bot] Jan 24, 2025
815b937
fix 'ap_rst' port polarity for active high case
dimdano Jan 28, 2025
888078b
support for partition interface in verilog testbench
dimdano Jan 29, 2025
6415928
support for MultiModelGraph predict using chained bridge file
dimdano Feb 14, 2025
72871ce
Add pytest for multi-graph and fix minor issues
dimdano Mar 3, 2025
ac191dd
pre-commit fixes
dimdano Mar 4, 2025
5cd7f45
removed pandas dependency in read_testbench_log
dimdano Mar 10, 2025
6050df2
Ensure stitched RTL simulation results align with CSim output
dimdano Mar 14, 2025
9044eeb
parallel subgraph compilation
dimdano Apr 16, 2025
fda9707
added additional checks in ip_stitcher
dimdano Apr 16, 2025
10d26ba
small improvements on MultiModelGraph
dimdano Apr 16, 2025
7263d7d
correct AXIS port slicing for Verilog simulation
dimdano Apr 30, 2025
92bee58
Generate Verilog testbench inputs using C++ bridge
dimdano May 14, 2025
a0fdb4f
Fix rebase conflict in ModelGraph object creation
dimdano May 19, 2025
bb9ec50
Major rewrite of multi-graph splitting. it now uses the optimized Mod…
dimdano May 23, 2025
482c5b8
minor fixes and improvements
dimdano May 26, 2025
5acc175
minor fixes
dimdano May 27, 2025
b10881f
skip stitching if a graph failed
dimdano May 28, 2025
2bd5dc2
Merge branch 'main' into make_multi_graph
dimdano Jun 26, 2025
cd8c7aa
final changes
dimdano Jun 26, 2025
8d792af
remove synthesis test
dimdano Jun 26, 2025
Binary file added docs/img/logo_small.png
130 changes: 130 additions & 0 deletions docs/ir/multimodelgraph.rst
@@ -0,0 +1,130 @@
=======================
MultiModelGraph Class
=======================

This page documents the ``MultiModelGraph`` class, which enables handling multiple subgraphs (each represented as a ``ModelGraph``) derived from a single original model.
The central concept is the division of a larger model into multiple smaller subgraphs at given layers, which can be useful for:

* Very large models
* Step-wise optimization
* Modular design flows

A ``MultiModelGraph`` manages these subgraphs, facilitating:

* Parallel building and synthesis
* Stitched designs (merging the subgraphs in HW after synthesis)
* Simulation and performance estimation of the stitched design

--------------
Keras Example
--------------

For example, when converting a Keras model, you can specify the layers at which to split the model directly:

.. code-block:: python

config = hls4ml.utils.config_from_keras_model(model, granularity='model')

hls_model = hls4ml.converters.convert_from_keras_model(
model,
hls_config=config,
backend='vitis',
)
hls_multigraph_model = hls4ml.model.to_multi_model_graph(hls_model, ['layer3', 'layer7'])

Here, the ``hls_multigraph_model`` is a ``MultiModelGraph`` containing three subgraphs. Each subgraph is a ``ModelGraph`` accessible via indexing: ``hls_multigraph_model[i]``.
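
Conceptually, providing N split layer names yields N + 1 consecutive subgraphs. The sketch below illustrates this partitioning on a plain list of layer names; it is mock logic for illustration only (the ``split_layer_sequence`` helper is hypothetical, not part of hls4ml, and whether a named layer starts or ends a segment is a detail of the real implementation):

```python
def split_layer_sequence(layers, split_at):
    """Partition a flat list of layer names into consecutive groups.

    Each name in ``split_at`` starts a new group, mirroring how N split
    points produce N + 1 subgraphs. Illustrative only: the real flow
    operates on ModelGraph internals, not plain name lists.
    """
    groups, current = [], []
    for name in layers:
        if name in split_at and current:
            groups.append(current)
            current = []
        current.append(name)
    groups.append(current)
    return groups

layers = ['layer1', 'layer2', 'layer3', 'layer4',
          'layer5', 'layer6', 'layer7', 'layer8']
subgraphs = split_layer_sequence(layers, ['layer3', 'layer7'])
print(len(subgraphs))  # 3 groups, matching the three subgraphs above
```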


----------------------------------
Key Methods for MultiModelGraph
----------------------------------

* :ref:`compile <mmg-compile-method>`
* :ref:`predict <mmg-predict-method>`
* :ref:`build <mmg-build-method>`
* :ref:`trace <mmg-trace-method>`

----

.. _mmg-compile-method:

``compile`` method
==================

Compiles all the individual ``ModelGraph`` subgraphs within the ``MultiModelGraph``. It also compiles a chained bridge file that links all subgraphs together, which is used by the ``predict`` method.

.. code-block:: python

hls_multigraph_model.compile()

----

.. _mmg-build-method:

``build`` method
================

Builds all subgraphs in parallel, each as a standalone ``ModelGraph`` project, and returns a report for each subgraph. If configured, it then runs the stitching flow in Vivado, connecting the individual exported IPs and allowing you to simulate the stitched design at the RTL level.

.. code-block:: python

report = hls_multigraph_model.build(.., export=True, stitch_design=True, sim_stitched_design=True, export_stitched_design=True)

The returned ``report`` contains results from each subgraph's build and, if stitching was performed, a combined report of the stitched design. Reports for individual ``ModelGraph`` instances are always accessible via
``MultiModelGraph.graph_reports``.
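
The exact schema of these reports is backend-dependent. As an illustrative sketch with mock data (the field names below are modeled loosely on Vivado report keys and are assumptions, not a stable hls4ml schema), per-subgraph results could be summarized like this:

```python
# Mock graph_reports: maps subgraph name -> parsed report dict.
# Key names are assumptions for illustration, not a guaranteed schema.
graph_reports = {
    'graph1': {'CSynthesisReport': {'BestLatency': '10', 'WorstLatency': '10'}},
    'graph2': {'CSynthesisReport': {'BestLatency': '25', 'WorstLatency': '30'}},
}

def summarize_latency(reports):
    """Sum per-subgraph worst-case latencies as a rough upper bound
    for the chained design (ignores pipelining overlap between IPs)."""
    return sum(int(r['CSynthesisReport']['WorstLatency']) for r in reports.values())

print(summarize_latency(graph_reports))  # 40
```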


----

.. _mmg-predict-method:

``predict`` method
==================

Performs a forward pass through the chained bridge file using C simulation (``sim='csim'``), producing output that matches the original model one-to-one. You can also use RTL simulation (``sim='rtl'``) to perform the forward pass at the register-transfer level. In this case, a Verilog testbench is generated dynamically and executed against the stitched IP design, providing a behavioral simulation that verifies latency and output at the hardware level. Note that the input data for the RTL simulation must have a single batch dimension.

.. code-block:: python

# Perform prediction using C-simulation (default)
y_csim = hls_multigraph_model.predict(X, sim='csim')

# Perform prediction using RTL simulation (behavioral)
y_rtl = hls_multigraph_model.predict(X, sim='rtl')
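
Because fixed-point rounding along the RTL path can differ slightly from C simulation, a tolerance-based comparison is the usual sanity check. A sketch with mock arrays standing in for the two ``predict`` outputs (plain NumPy; the tolerance value is an arbitrary choice, not an hls4ml recommendation):

```python
import numpy as np

# Stand-ins for y_csim and y_rtl returned by the calls above.
y_csim = np.array([[0.12, 0.88], [0.40, 0.60]])
y_rtl = np.array([[0.121, 0.879], [0.401, 0.599]])

# Element-wise agreement within an absolute tolerance.
matches = np.isclose(y_csim, y_rtl, atol=1e-2)
assert matches.all(), 'RTL output diverges from C simulation beyond tolerance'

# Worst-case absolute difference across all outputs.
print(np.max(np.abs(y_csim - y_rtl)))
```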



--------------------------
Summary
--------------------------

The ``MultiModelGraph`` class is a tool for modular hardware design. By splitting a large neural network into multiple subgraphs, building each independently, and then stitching them together, you gain flexibility and parallelism, and you enable hierarchical design, incremental optimization, and integrated system-level simulation.


Notes and Known Issues
=======================

Graph Splitting
---------------

- Splitting in the middle of a branched architecture (e.g., ResNet skip connections) is currently unsupported.
- Each split subgraph must have exactly one input.

Multiple Inputs & Outputs
-------------------------

- The final NN output can consist of multiple output layers.
- For networks with multiple input layers (a relatively uncommon case), proper synchronization is required in the testbench to drive the inputs, especially for ``io_stream`` interfaces.

Simulation Discrepancies
------------------------

- Users should carefully verify functional equivalence, particularly for models that use the ``io_stream`` interface.
- These discrepancies are more noticeable with raw output logits; applying a softmax layer at the model output can often mask these differences, but this should be done with caution.
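
The masking effect is easy to see numerically: an absolute error on raw logits typically shrinks after softmax normalization. An illustrative NumPy sketch with made-up logit values (not actual simulation output):

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D array.
    e = np.exp(x - np.max(x))
    return e / e.sum()

logits_csim = np.array([2.00, 1.00, 0.10])
logits_rtl = np.array([2.03, 0.98, 0.12])  # small fixed-point discrepancies

raw_err = np.max(np.abs(logits_csim - logits_rtl))
soft_err = np.max(np.abs(softmax(logits_csim) - softmax(logits_rtl)))

# The gap after softmax is smaller than on the raw logits,
# which is exactly why softmax can hide real discrepancies.
print(raw_err, soft_err)
```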

TODOs
-----------------------

- Currently tested with Vitis 2024.1. Investigate compatibility with other versions.
- Add support for Verilator-based simulation to enable faster RTL simulation.
- Investigate the ``io_stream`` interface (output discrepancies, FIFO optimization).
- Investigate differences in resource utilization for the ``io_parallel`` interface.
150 changes: 131 additions & 19 deletions hls4ml/backends/vitis/vitis_backend.py
Contributor:
For the future: export a tcl script for offline build

@@ -1,9 +1,19 @@
import importlib.util
import json
import os
import shutil
import subprocess
import sys

from hls4ml.backends import VivadoBackend
from hls4ml.model.flow import get_flow, register_flow
from hls4ml.report import parse_vivado_report
from hls4ml.report import aggregate_graph_reports, parse_vivado_report
from hls4ml.utils.simulation_utils import (
annotate_axis_stream_widths,
prepare_tb_inputs,
read_testbench_log,
write_verilog_testbench,
)


class VitisBackend(VivadoBackend):
@@ -98,29 +108,131 @@ def build(
export=False,
vsynth=False,
fifo_opt=False,
log_to_stdout=True,
):
if 'linux' in sys.platform:
found = os.system('command -v vitis_hls > /dev/null')
if found != 0:
raise Exception('Vitis HLS installation not found. Make sure "vitis_hls" is on PATH.')

curr_dir = os.getcwd()
os.chdir(model.config.get_output_dir())
os.system(
(
'vitis_hls -f build_prj.tcl "reset={reset} csim={csim} synth={synth} cosim={cosim} '
'validation={validation} export={export} vsynth={vsynth} fifo_opt={fifo_opt}"'
).format(
reset=reset,
csim=csim,
synth=synth,
cosim=cosim,
validation=validation,
export=export,
vsynth=vsynth,
fifo_opt=fifo_opt,
)
build_command = (
'vitis_hls -f build_prj.tcl "reset={reset} csim={csim} synth={synth} cosim={cosim} '
'validation={validation} export={export} vsynth={vsynth} fifo_opt={fifo_opt}"'
).format(
reset=reset,
csim=csim,
synth=synth,
cosim=cosim,
validation=validation,
export=export,
vsynth=vsynth,
fifo_opt=fifo_opt,
)
os.chdir(curr_dir)

return parse_vivado_report(model.config.get_output_dir())
output_dir = model.config.get_output_dir()
stdout_log = os.path.join(output_dir, 'build_stdout.log')
stderr_log = os.path.join(output_dir, 'build_stderr.log')

stdout_target = None if log_to_stdout else open(stdout_log, 'w')
stderr_target = None if log_to_stdout else open(stderr_log, 'w')

try:
process = subprocess.Popen(
build_command, shell=True, cwd=output_dir, stdout=stdout_target, stderr=stderr_target, text=True
)
process.communicate()

if process.returncode != 0:
raise Exception(f'Build failed for {model.config.get_project_name()}. See logs for details.')
finally:
if not log_to_stdout:
stdout_target.close()
stderr_target.close()

return parse_vivado_report(output_dir)

def build_stitched_design(
self,
model,
stitch_design=True,
sim_stitched_design=False,
export_stitched_design=False,
graph_reports=None,
simulation_input_data=None,
):

nn_config = model.nn_config
os.makedirs(nn_config['OutputDir'], exist_ok=True)
stitched_design_dir = os.path.join(nn_config['OutputDir'], nn_config['StitchedProjectName'])
if stitch_design:
if os.path.exists(stitched_design_dir):
shutil.rmtree(stitched_design_dir)
os.makedirs(stitched_design_dir)

spec = importlib.util.find_spec('hls4ml')
hls4ml_path = os.path.dirname(spec.origin)
ip_stitcher_path = os.path.join(hls4ml_path, 'templates/vivado/ip_stitcher.tcl')
stdout_log = os.path.join(stitched_design_dir, 'stitcher_stdout.log')
stderr_log = os.path.join(stitched_design_dir, 'stitcher_stderr.log')
nn_config_path = os.path.join(stitched_design_dir, 'nn_config.json')
testbench_path = os.path.join(stitched_design_dir, 'testbench.v')
testbench_log_path = os.path.join(stitched_design_dir, 'testbench_log.csv')

try:
shutil.copy(ip_stitcher_path, stitched_design_dir)
except Exception as e:
print(f"Error: {e}. Cannot copy 'ip_stitcher.tcl' to {nn_config['StitchedProjectName']} folder.")

# Verilog output bitwidths are rounded up and may differ from HLS output bitwidths
if nn_config['outputs'][0]['pragma'] == 'stream':
last_graph_project_path = os.path.join(
model.graphs[-1].config.get_output_dir(), model.graphs[-1].config.get_project_dir()
)
annotate_axis_stream_widths(nn_config, last_graph_project_path)
with open(nn_config_path, "w") as file:
json.dump(nn_config, file, indent=4)

if sim_stitched_design:
write_verilog_testbench(nn_config, testbench_path)
tb_inputs = prepare_tb_inputs(simulation_input_data, nn_config['inputs'])
model.write_tb_inputs(tb_inputs, stitched_design_dir)
print('Verilog testbench and its input data were generated.')

print('Running build process of stitched IP...\n')
stitch_command = [
Contributor:

I would suggest exporting this into the build_prj.tcl and invoke it from these, as having hls4ml creating the model and put them on another machine for HLS/logic could be a common workflow.

Contributor Author:

The stitch_command is relatively fast and only runs after all the individual subgraph builds are complete. However, since hls4ml manages these builds in parallel using a Python thread pool, supporting this workflow on a remote server would require a Python script that mimics this behavior, so essentially looping over each subgraph directory and running its corresponding build_prj.tcl in parallel using threads or processes. It's not hard to set up and I will do it once we finalize the flow.

Contributor:

I would suggest exporting this to a .tcl somewhere (e.g., modify the top build.tcl) for offline build. For parallelized build, maybe use xargs or parallel in args, or just leave it to the user and only attach the script for stitching the multigraph design after individual builds.

'vivado',
'-mode',
'batch',
'-nojournal',
'-nolog',
'-notrace',
'-source',
ip_stitcher_path,
'-tclargs',
f'stitch_design={int(stitch_design)}',
f'sim_design={int(sim_stitched_design)}',
f'export_design={int(export_stitched_design)}',
f"stitch_project_name={nn_config['StitchedProjectName']}",
f"original_project_name={nn_config['OriginalProjectName']}",
'sim_verilog_file=testbench.v',
]

with open(stdout_log, 'w') as stdout_file, open(stderr_log, 'w') as stderr_file:
process = subprocess.Popen(
stitch_command, cwd=stitched_design_dir, stdout=stdout_file, stderr=stderr_file, text=True, shell=False
)
process.communicate()
if process.returncode != 0:
raise Exception(f"Stitching failed for {nn_config['StitchedProjectName']}. See logs for details.")

stitched_report = {'StitchedDesignReport': {}}
if stitch_design:
stitched_report = aggregate_graph_reports(graph_reports)

if sim_stitched_design:
testbench_output = read_testbench_log(testbench_log_path, nn_config['outputs'])
stitched_report['BehavSimResults'] = testbench_output['BehavSimResults']
stitched_report['StitchedDesignReport']['BestLatency'] = testbench_output['BestLatency']
stitched_report['StitchedDesignReport']['WorstLatency'] = testbench_output['WorstLatency']

return stitched_report
1 change: 1 addition & 0 deletions hls4ml/backends/vivado/passes/transform_types.py
@@ -31,6 +31,7 @@ def transform(self, model, node):
new_var = self.array_var_converter.convert(var, pragma='stream')
elif io_type == 'io_parallel':
if out_name in node.model.inputs:
# NOTE this needs to be changed to partition
new_var = self.array_var_converter.convert(var, pragma='reshape')
elif isinstance(var, InplaceTensorVariable):
new_var = self.inplace_array_var_converter.convert(var, pragma='')
4 changes: 1 addition & 3 deletions hls4ml/converters/keras_v2_to_hls.py
@@ -357,6 +357,4 @@ def parse_keras_model(model_arch, reader):
def keras_v2_to_hls(config):
model_arch, reader = get_model_arch(config)
layer_list, input_layers, output_layers, _ = parse_keras_model(model_arch, reader)
print('Creating HLS model')
hls_model = ModelGraph.from_layer_list(config, layer_list, input_layers, output_layers)
return hls_model
return ModelGraph.from_layer_list(config, layer_list, input_layers, output_layers)
4 changes: 1 addition & 3 deletions hls4ml/converters/pytorch_to_hls.py
@@ -425,6 +425,4 @@ def parse_pytorch_model(config, verbose=True):
@requires('_torch')
def pytorch_to_hls(config):
layer_list, input_layers, output_layers = parse_pytorch_model(config)
print('Creating HLS model')
hls_model = ModelGraph.from_layer_list(config, layer_list, inputs=input_layers, outputs=output_layers)
return hls_model
return ModelGraph.from_layer_list(config, layer_list, inputs=input_layers, outputs=output_layers)
2 changes: 1 addition & 1 deletion hls4ml/model/__init__.py
@@ -1 +1 @@
from hls4ml.model.graph import HLSConfig, ModelGraph # noqa: F401
from hls4ml.model.graph import HLSConfig, ModelGraph, MultiModelGraph, to_multi_model_graph # noqa: F401