Support for multi graph build #1174
@@ -0,0 +1,130 @@
=======================
MultiModelGraph Class
=======================

This page documents the ``MultiModelGraph`` class, which enables handling multiple subgraphs (each represented as a ``ModelGraph``) derived from a single original model.
The central idea is to split a larger model into multiple smaller subgraphs at specified layers, which can be useful for:

* Very large models
* Step-wise optimization
* Modular design flows

A ``MultiModelGraph`` manages these subgraphs, facilitating:

* Parallel building and synthesis
* Stitched designs (merging the subgraphs in HW after synthesis)
* Simulation and performance estimation of the stitched design

--------------
Keras Example
--------------

For example, when converting a Keras model, you can specify the layers at which to split the model directly:

.. code-block:: python

    config = hls4ml.utils.config_from_keras_model(model, granularity='model')

    hls_model = hls4ml.converters.convert_from_keras_model(
        model,
        hls_config=config,
        backend='vitis',
    )
    hls_multigraph_model = hls4ml.model.to_multi_model_graph(hls_model, ['layer3', 'layer7'])

Here, ``hls_multigraph_model`` is a ``MultiModelGraph`` containing three subgraphs. Each subgraph is a ``ModelGraph`` accessible via indexing: ``hls_multigraph_model[i]``.
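
For illustration, the subgraphs can be inspected individually. This is a minimal sketch: indexing is the documented access pattern, while ``get_layers()`` and ``layer.name`` are the usual ``ModelGraph`` accessors and are assumed here rather than specified by this page.

.. code-block:: python

    # The example above yields three subgraphs; each one is a regular ModelGraph
    for i in range(3):
        subgraph = hls_multigraph_model[i]
        layer_names = [layer.name for layer in subgraph.get_layers()]
        print(f'Subgraph {i}: {len(layer_names)} layers')
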
----------------------------------
Key Methods for MultiModelGraph
----------------------------------

* :ref:`compile <mmg-compile-method>`
* :ref:`predict <mmg-predict-method>`
* :ref:`build <mmg-build-method>`
* :ref:`trace <mmg-trace-method>`

----

.. _mmg-compile-method:

``compile`` method
==================

Compiles all the individual ``ModelGraph`` subgraphs within the ``MultiModelGraph``. It also compiles a chained bridge file that links all the subgraphs together, which is used by the ``predict`` function.

.. code-block:: python

    hls_multigraph_model.compile()
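
For instance, once compiled, the chained bridge can be exercised directly through ``predict``. This is a short usage sketch; ``X`` stands for any appropriately shaped NumPy input and is an assumption of the example, not part of the API.

.. code-block:: python

    hls_multigraph_model.compile()
    y = hls_multigraph_model.predict(X)  # runs through the chained bridge, C simulation by default
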
----

.. _mmg-build-method:

``build`` method
================

Builds all subgraphs in parallel, each as if it were a standalone ``ModelGraph`` project, and returns reports for each subgraph. If configured, it then runs the stitching flow in Vivado, connecting the individual exported IPs and allowing you to simulate the stitched design at the RTL level.

.. code-block:: python

    report = hls_multigraph_model.build(.., export=True, stitch_design=True, sim_stitched_design=True, export_stitched_design=True)

The returned ``report`` contains results from each subgraph's build and, if stitching was performed, a combined report of the stitched design. Reports for individual ``ModelGraph`` instances are always accessible via ``MultiModelGraph.graph_reports``.
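
As a rough sketch of how these reports might be inspected: ``graph_reports`` is the documented attribute, while the stitched-report keys below mirror those produced by ``build_stitched_design`` in the backend code later in this diff, so their exact nesting in the returned ``report`` is an assumption.

.. code-block:: python

    # Per-subgraph HLS reports are always available
    print(hls_multigraph_model.graph_reports)

    # If stitching and behavioral simulation were run, the combined report carries
    # latency figures extracted from the testbench log
    print(report['StitchedDesignReport'].get('BestLatency'))
    print(report['StitchedDesignReport'].get('WorstLatency'))
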
----

.. _mmg-predict-method:

``predict`` method
==================

Performs a forward pass through the chained bridge file using C simulation (``sim='csim'``), providing one-to-one output with the original model. You can also leverage RTL simulation (``sim='rtl'``) to perform the forward pass at the register-transfer level. In this case, a Verilog testbench is dynamically generated and executed against the stitched IP design, providing a behavioral simulation that accurately verifies latency and output at the hardware level. Note that the input data for the RTL simulation must have a single batch dimension.

.. code-block:: python

    # Perform prediction using C-simulation (default)
    y_csim = hls_multigraph_model.predict(X, sim='csim')

    # Perform prediction using RTL simulation (behavioral)
    y_rtl = hls_multigraph_model.predict(X, sim='rtl')

--------------------------
Summary
--------------------------

The ``MultiModelGraph`` class is a tool for modular hardware design. By splitting a large neural network into multiple subgraphs, building each independently, and then stitching them together, you gain flexibility and parallelism, and you enable hierarchical design, incremental optimization, and integrated system-level simulation.

Notes and Known Issues
=======================

Graph Splitting
---------------

- Splitting in the middle of a branched architecture (e.g., ResNet skip connections) is currently unsupported.
- Each split subgraph must have exactly one input.

Multiple Inputs & Outputs
-------------------------

- The final network output can consist of multiple output layers.
- For networks with multiple input layers (a relatively uncommon case), proper synchronization is required in the testbench to drive the inputs, especially for ``io_stream`` interfaces.

Simulation Discrepancies
------------------------

- Users should carefully verify functional equivalence, particularly for models that use the ``io_stream`` interface (see the sketch after this list).
- These discrepancies are more noticeable with raw output logits; applying a softmax layer at the model output can often mask these differences, but this should be used with caution.
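
A minimal way to check functional equivalence, assuming ``model`` is the original Keras model and ``X`` is a NumPy array of test vectors (both names are assumptions of this sketch, not requirements of the API):

.. code-block:: python

    import numpy as np

    y_keras = model.predict(X)
    y_csim = hls_multigraph_model.predict(X, sim='csim')
    y_rtl = hls_multigraph_model.predict(X, sim='rtl')

    # Compare raw outputs; acceptable tolerances depend on the chosen fixed-point precision
    print('csim vs keras:', np.max(np.abs(np.asarray(y_csim) - np.asarray(y_keras))))
    print('rtl  vs csim :', np.max(np.abs(np.asarray(y_rtl) - np.asarray(y_csim))))
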
TODOs
-----------------------

- Currently tested with Vitis 2024.1. Investigate compatibility with other versions.
- Add support for Verilator-based simulation to enable faster RTL simulation.
- Investigate the ``io_stream`` interface (output discrepancies, FIFO optimization).
- Investigate differences in resource utilization for the ``io_parallel`` interface.

@@ -1,9 +1,19 @@
+import importlib.util
+import json
 import os
+import shutil
+import subprocess
 import sys

 from hls4ml.backends import VivadoBackend
 from hls4ml.model.flow import get_flow, register_flow
-from hls4ml.report import parse_vivado_report
+from hls4ml.report import aggregate_graph_reports, parse_vivado_report
+from hls4ml.utils.simulation_utils import (
+    annotate_axis_stream_widths,
+    prepare_tb_inputs,
+    read_testbench_log,
+    write_verilog_testbench,
+)


 class VitisBackend(VivadoBackend):
@@ -98,29 +108,131 @@ def build(
         export=False,
         vsynth=False,
         fifo_opt=False,
+        log_to_stdout=True,
     ):
         if 'linux' in sys.platform:
             found = os.system('command -v vitis_hls > /dev/null')
             if found != 0:
                 raise Exception('Vitis HLS installation not found. Make sure "vitis_hls" is on PATH.')

-        curr_dir = os.getcwd()
-        os.chdir(model.config.get_output_dir())
-        os.system(
-            (
-                'vitis_hls -f build_prj.tcl "reset={reset} csim={csim} synth={synth} cosim={cosim} '
-                'validation={validation} export={export} vsynth={vsynth} fifo_opt={fifo_opt}"'
-            ).format(
-                reset=reset,
-                csim=csim,
-                synth=synth,
-                cosim=cosim,
-                validation=validation,
-                export=export,
-                vsynth=vsynth,
-                fifo_opt=fifo_opt,
-            )
-        )
-        os.chdir(curr_dir)
+        build_command = (
+            'vitis_hls -f build_prj.tcl "reset={reset} csim={csim} synth={synth} cosim={cosim} '
+            'validation={validation} export={export} vsynth={vsynth} fifo_opt={fifo_opt}"'
+        ).format(
+            reset=reset,
+            csim=csim,
+            synth=synth,
+            cosim=cosim,
+            validation=validation,
+            export=export,
+            vsynth=vsynth,
+            fifo_opt=fifo_opt,
+        )

-        return parse_vivado_report(model.config.get_output_dir())
+        output_dir = model.config.get_output_dir()
+        stdout_log = os.path.join(output_dir, 'build_stdout.log')
+        stderr_log = os.path.join(output_dir, 'build_stderr.log')
+
+        stdout_target = None if log_to_stdout else open(stdout_log, 'w')
+        stderr_target = None if log_to_stdout else open(stderr_log, 'w')
+
+        try:
+            process = subprocess.Popen(
+                build_command, shell=True, cwd=output_dir, stdout=stdout_target, stderr=stderr_target, text=True
+            )
+            process.communicate()
+
+            if process.returncode != 0:
+                raise Exception(f'Build failed for {model.config.get_project_name()}. See logs for details.')
+        finally:
+            if not log_to_stdout:
+                stdout_target.close()
+                stderr_target.close()
+
+        return parse_vivado_report(output_dir)
+
+    def build_stitched_design(
+        self,
+        model,
+        stitch_design=True,
+        sim_stitched_design=False,
+        export_stitched_design=False,
+        graph_reports=None,
+        simulation_input_data=None,
+    ):
+
+        nn_config = model.nn_config
+        os.makedirs(nn_config['OutputDir'], exist_ok=True)
+        stitched_design_dir = os.path.join(nn_config['OutputDir'], nn_config['StitchedProjectName'])
+        if stitch_design:
+            if os.path.exists(stitched_design_dir):
+                shutil.rmtree(stitched_design_dir)
+            os.makedirs(stitched_design_dir)
+
+        spec = importlib.util.find_spec('hls4ml')
+        hls4ml_path = os.path.dirname(spec.origin)
+        ip_stitcher_path = os.path.join(hls4ml_path, 'templates/vivado/ip_stitcher.tcl')
+        stdout_log = os.path.join(stitched_design_dir, 'stitcher_stdout.log')
+        stderr_log = os.path.join(stitched_design_dir, 'stitcher_stderr.log')
+        nn_config_path = os.path.join(stitched_design_dir, 'nn_config.json')
+        testbench_path = os.path.join(stitched_design_dir, 'testbench.v')
+        testbench_log_path = os.path.join(stitched_design_dir, 'testbench_log.csv')
+
+        try:
+            shutil.copy(ip_stitcher_path, stitched_design_dir)
+        except Exception as e:
+            print(f"Error: {e}. Cannot copy 'ip_stitcher.tcl' to {nn_config['StitchedProjectName']} folder.")
+
+        # Verilog output bitwidths are rounded up and may differ from HLS output bitwidths
+        if nn_config['outputs'][0]['pragma'] == 'stream':
+            last_graph_project_path = os.path.join(
+                model.graphs[-1].config.get_output_dir(), model.graphs[-1].config.get_project_dir()
+            )
+            annotate_axis_stream_widths(nn_config, last_graph_project_path)
+        with open(nn_config_path, "w") as file:
+            json.dump(nn_config, file, indent=4)
+
+        if sim_stitched_design:
+            write_verilog_testbench(nn_config, testbench_path)
+            tb_inputs = prepare_tb_inputs(simulation_input_data, nn_config['inputs'])
+            model.write_tb_inputs(tb_inputs, stitched_design_dir)
+            print('Verilog testbench and its input data were generated.')
+
+        print('Running build process of stitched IP...\n')
+        stitch_command = [
+            'vivado',
+            '-mode',
+            'batch',
+            '-nojournal',
+            '-nolog',
+            '-notrace',
+            '-source',
+            ip_stitcher_path,
+            '-tclargs',
+            f'stitch_design={int(stitch_design)}',
+            f'sim_design={int(sim_stitched_design)}',
+            f'export_design={int(export_stitched_design)}',
+            f"stitch_project_name={nn_config['StitchedProjectName']}",
+            f"original_project_name={nn_config['OriginalProjectName']}",
+            'sim_verilog_file=testbench.v',
+        ]

(Review comment on the stitch_command block: "I would suggest exporting this to a .tcl somewhere (e.g., modify the top build.tcl) for offline build. For parallelized build, maybe use xargs or parallel in args, or just leave it to the user and only attach the script for stitching the multigraph design after individual builds.")

+        with open(stdout_log, 'w') as stdout_file, open(stderr_log, 'w') as stderr_file:
+            process = subprocess.Popen(
+                stitch_command, cwd=stitched_design_dir, stdout=stdout_file, stderr=stderr_file, text=True, shell=False
+            )
+            process.communicate()
+            if process.returncode != 0:
+                raise Exception(f"Stitching failed for {nn_config['StitchedProjectName']}. See logs for details.")
+
+        stitched_report = {'StitchedDesignReport': {}}
+        if stitch_design:
+            stitched_report = aggregate_graph_reports(graph_reports)
+
+        if sim_stitched_design:
+            testbench_output = read_testbench_log(testbench_log_path, nn_config['outputs'])
+            stitched_report['BehavSimResults'] = testbench_output['BehavSimResults']
+            stitched_report['StitchedDesignReport']['BestLatency'] = testbench_output['BestLatency']
+            stitched_report['StitchedDesignReport']['WorstLatency'] = testbench_output['WorstLatency']
+
+        return stitched_report

@@ -1 +1 @@
-from hls4ml.model.graph import HLSConfig, ModelGraph # noqa: F401
+from hls4ml.model.graph import HLSConfig, ModelGraph, MultiModelGraph, to_multi_model_graph # noqa: F401

Review comment: For the future: export a tcl script for offline build.