4 changes: 4 additions & 0 deletions .gitmodules
@@ -10,3 +10,7 @@
[submodule "hls4ml/templates/catapult/ac_math"]
path = hls4ml/templates/catapult/ac_math
url = https://github.com/hlslibs/ac_math.git
[submodule "hls4ml/contrib/Coyote"]
path = hls4ml/contrib/Coyote
url = https://github.com/fpgasystems/Coyote.git
branch = integrations/hls4ml
67 changes: 67 additions & 0 deletions docs/backend/accelerator.rst
@@ -75,3 +75,70 @@ The ``predict`` method will send the input data to the PL and return the output

nn = NeuralNetworkOverlay('hls4ml_nn.bit', X_test.shape, y_test.shape)
y_hw, latency, throughput = nn.predict(X_test, profile=True)


=================
CoyoteAccelerator
=================

The **CoyoteAccelerator** backend of ``hls4ml`` leverages the `Coyote shell <https://github.com/fpgasystems/Coyote>`_ to easily deploy models on PCIe-attached Alveo FPGAs.
Coyote is an open-source research shell that facilitates the deployment of applications on FPGAs, as well as the integration of FPGAs into larger computer systems.
Some of its features include:

- Multi-tenancy
- Virtualized memory
- Optimized data movement
- Dynamic reconfiguration
- Automatic work scheduling and memory striping
- Networking for distributed applications

The list of supported boards is available in the `Coyote documentation <https://fpgasystems.github.io/Coyote/intro/quick-start.html>`_.
The Coyote backend can be used to deploy hls4ml models from both Python and C++. While the current focus is on inference,
the backend can easily be extended to support dynamic reconfiguration of models, as well as distributed inference across multiple FPGAs.

CoyoteOverlay
================================

Similar to the VivadoAccelerator backend, the Coyote backend creates a custom **neural network overlay** that interacts with the FPGA.
This overlay can be used to provide inputs, run inference and retrieve the predictions. Additionally, the overlay provides a utility
function to load the model bitstream and driver on some clusters; on others, users need to load the bitstream and driver manually.
For guidance, see the `Coyote documentation <https://fpgasystems.github.io/Coyote/intro/quick-start.html#deploying-coyote>`_.
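
As a minimal sketch (assuming the project was generated in ``hls4ml_prj_coyote`` with the default project name), the overlay
can be constructed as follows and, on the ETH Zurich HACC cluster, the bitstream and driver can be loaded through its utility function:

.. code-block:: Python

    from hls4ml.backends.coyote_accelerator.coyote_accelerator_overlay import CoyoteOverlay

    # Path to the hls4ml output directory (output_dir passed during conversion)
    overlay = CoyoteOverlay('hls4ml_prj_coyote')

    # Only available on the ETH Zurich HACC cluster; on other systems, load the
    # bitstream and driver manually, as described in the Coyote documentation
    overlay.program_hacc_fpga()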

C++ binary
================================

Additionally, the Coyote backend generates and compiles a C++ program that can be used to run inference on the FPGA.
The binary can be found in ``<hls4ml-output-dir>/build/<project-name>_cyt_sw/bin/test`` and, when launched, it
runs inference using the inputs from ``tb_data``. As with the Python overlay, the bitstream and driver must be loaded before running inference.

Example
======================

Similar to the ``VivadoAccelerator`` backend, we first generate a bitstream from a Keras model ``model`` and a configuration.

.. code-block:: Python

import hls4ml
config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(model,
hls_config=config,
output_dir='hls4ml_prj_coyote',
backend='CoyoteAccelerator',
board='u55c')
hls_model.build(bitfile=True)
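
``build`` also accepts further options that are forwarded to the Coyote build flow. As a sketch, a call with a few of these
options spelled out (illustrative values; see the backend's ``build`` method for the full list and defaults):

.. code-block:: Python

    hls_model.build(
        bitfile=True,              # generate the full Coyote bitstream
        device='u55c',             # target Alveo card: u55c, u280 or u250
        csim=True,                 # run C simulation of the HLS project
        synth=True,                # run HLS synthesis
        timing_opt=False,          # extra optimizations during place-and-route
        hls_clock_period=4,        # HLS clock period in ns (the hardware runs at 250 MHz)
        hls_clock_uncertainty=27,  # HLS clock uncertainty, in percent
    )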

After this command completes, the FPGA must be programmed with the bitstream and the Coyote driver must be loaded.
For some platforms, Coyote provides utility functions to load the bitstream and driver; for others, this can be achieved using
the Vivado hardware manager and Linux commands. More details can be found in the `Coyote documentation <https://fpgasystems.github.io/Coyote/intro/quick-start.html#deploying-coyote>`_.

Finally, we can create a ``CoyoteOverlay`` object, which can be used to run inference on the FPGA. As noted above, the overlay also provides a utility
function to load the model bitstream and driver on some clusters.
When running inference, we must provide the input tensor and the shape of the output tensor (used to allocate the buffers for the data transfer).
Optionally, a batch size can be specified.
The ``predict`` method will send the input data to the FPGA and return the output data ``y_hw``.

.. code-block:: Python

from hls4ml.backends.coyote_accelerator.coyote_accelerator_overlay import CoyoteOverlay

overlay = CoyoteOverlay('hls4ml_prj_coyote')
y_hw = overlay.predict(x, (1, ), BATCH_SIZE)
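
For larger datasets, ``predict`` accepts a 2-D input array and iterates over it in batches, returning an array of shape
``(num_samples, *y_shape)``. A sketch, assuming a hypothetical test set of 1024 flattened samples and a model with ten outputs:

.. code-block:: Python

    import numpy as np

    X_test = np.random.rand(1024, 16).astype(np.float32)  # hypothetical test data, 16 features per sample
    y_hw = overlay.predict(X_test, y_shape=(10,), batch_size=32)
    print(y_hw.shape)  # (1024, 10)
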
3 changes: 3 additions & 0 deletions hls4ml/backends/__init__.py
@@ -11,10 +11,13 @@

from hls4ml.backends.vitis.vitis_backend import VitisBackend # isort: skip

from hls4ml.backends.coyote_accelerator.coyote_accelerator_backend import CoyoteAcceleratorBackend

register_backend('Vivado', VivadoBackend)
register_backend('VivadoAccelerator', VivadoAcceleratorBackend)
register_backend('Vitis', VitisBackend)
register_backend('Quartus', QuartusBackend)
register_backend('Catapult', CatapultBackend)
register_backend('SymbolicExpression', SymbolicExpressionBackend)
register_backend('oneAPI', OneAPIBackend)
register_backend('CoyoteAccelerator', CoyoteAcceleratorBackend)
Empty file.
150 changes: 150 additions & 0 deletions hls4ml/backends/coyote_accelerator/coyote_accelerator_backend.py
@@ -0,0 +1,150 @@
import os
import subprocess
from hls4ml.model.flow import get_flow, register_flow
from hls4ml.backends import VitisBackend, VivadoBackend

class CoyoteAcceleratorBackend(VitisBackend):
"""
The CoyoteAccelerator backend, which deploys hls4ml models on a PCIe-attached Alveo FPGA
Underneath, it uses the Coyote shell (https://github.com/fpgasystems/Coyote),
which offers high-performance data movement, networking capabilities, multi-tenancy,
partial reconfiguration, etc. This backend has some similarities with the VitisAccelerator
backend, but the underlying platforms are different. The implementation of this backend
remains mostly simple, inheriting most of the functionality from the Vitis backend and
providing the necessary infrastructure to run model inference on Alveo boards.

Currently, this backend supports batched inference of a single model on hardware.
In the future, it can easily be extended with the following capabilities, leveraging
Coyote's features:
- Distributed inference
- Multiple parallel instances of hls4ml models (same or distinct models)
- Dynamic, run-time reconfiguration of models

Generic examples of Coyote can be found at the above-mentioned repository, under examples/
"""

def __init__(self):
super(VivadoBackend, self).__init__(name='CoyoteAccelerator')
self._register_layer_attributes()
self._register_flows()

def _register_flows(self):
writer_passes = ['make_stamp', 'coyoteaccelerator:write_hls']
self._writer_flow = register_flow('write', writer_passes, requires=['vitis:ip'], backend=self.name)

ip_flow_requirements = get_flow('vitis:ip').requires.copy()
self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name)

def compile(self, model):
"""
Compiles the hls4ml model for software emulation

Args:
model (ModelGraph): hls4ml model to compile

Returns:
lib_name (str): The name of the compiled library
"""
lib_name = None
ret_val = subprocess.run(
['./build_lib.sh'],
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
cwd=model.config.get_output_dir(),
)
if ret_val.returncode != 0:
print(ret_val.stdout)
raise Exception(f'Failed to compile project "{model.config.get_project_name()}"')
lib_name = '{}/build/{}-{}.so'.format(
model.config.get_output_dir(), model.config.get_project_name(), model.config.get_config_value('Stamp')
)

return lib_name

def build(
self,
model,
device: str = 'u55c',
reset: bool = False,
csim: bool = True,
synth: bool = True,
cosim: bool = False,
validation: bool = False,
csynth: bool = False,
bitfile: bool = False,
timing_opt: bool = False,
hls_clock_period: float = 4,
hls_clock_uncertainty: float = 27
):
"""
Synthesizes the hls4ml model bitstream as part of the Coyote shell
and compiles the host-side software to control the FPGA and run model inference

Args:
model (ModelGraph): hls4ml model to synthesize
device (str, optional): Target Alveo FPGA card; currently supported: u55c, u280 and u250
reset (bool, optional): Reset HLS project, if a previous one is found
csim (bool, optional): Run C-Simulation of the HLS project
synth (bool, optional): Run HLS synthesis
cosim (bool, optional): Run HLS co-simulation
validation (bool, optional): Validate results between C-Sim and Co-Sim
csynth (bool, optional): Run Coyote synthesis using Vivado, which will synthesize the model in a vFPGA
bitfile (bool, optional): Generate Coyote bitstream
timing_opt (bool, optional): Run additional optimizations when running PnR during bitstream generation
hls_clock_period (float, optional): Clock period to be used for HLS synthesis
hls_clock_uncertainty (float, optional): Clock uncertainty to be used for HLS synthesis

NOTE: Currently, the hardware will synthesize with a default clock period of 4 ns (250 MHz),
since this is the default frequency of Coyote (the XDMA core defaults to 250 MHz). Coyote allows
one to specify a different clock period for the model and use a clock-domain crossing (CDC) between the
XDMA region and the model. This option is currently not exposed as part of the hls4ml backend, but advanced
users can easily set it in the CMake configuration of Coyote.

NOTE: While the hardware will synthesize at 250 MHz, users can optionally pass a different HLS clock period.
This is primarily a work-around for when HLS synthesizes a kernel that doesn't meet timing during PnR.
The "trick" is to run HLS synthesis at a higher clock frequency than the target (or provide a higher uncertainty).

TODO: Add functionality to parse synthesis reports
"""
curr_dir = os.getcwd()

# Synthesize hardware
cmake_cmd = (
f'cmake ../../ '
f'-DFLOW=hw '
f'-DFDEV_NAME={device} '
f'-DBUILD_OPT={int(timing_opt)} '
f'-DEN_HLS_RESET={int(reset)} '
f'-DEN_HLS_CSIM={int(csim)} '
f'-DEN_HLS_SYNTH={int(synth)} '
f'-DEN_HLS_COSIM={int(cosim)} '
f'-DEN_HLS_VALIDATION={int(validation)} '
f'-DHLS_CLOCK_PERIOD={hls_clock_period} '
f'-DHLS_CLOCK_UNCERTAINTY="{str(hls_clock_uncertainty)}%"'
)

if not os.path.exists(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_hw'):
os.mkdir(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_hw')
os.chdir(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_hw')
os.system(cmake_cmd)

if bitfile:
os.system('make project && make bitgen')
elif csynth:
os.system('make project && make synth')
else:
os.system('make project')

os.chdir(curr_dir)

# Compile host software
cmake_cmd = 'cmake ../../ -DFLOW=sw'
if not os.path.exists(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_sw'):
os.mkdir(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_sw')
os.chdir(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_sw')
os.system(cmake_cmd)
os.system('make')
os.chdir(curr_dir)

104 changes: 104 additions & 0 deletions hls4ml/backends/coyote_accelerator/coyote_accelerator_overlay.py
@@ -0,0 +1,104 @@
import os
import time
import ctypes
import logging
import numpy as np

class CoyoteOverlay:
"""
CoyoteOverlay class, similar to NeuralNetworkOverlay for the VivadoAccelerator backend
This class can be used to run model inference on the FPGA with the CoyoteAccelerator backend
"""
def __init__(self, path: str, project_name: str = 'myproject'):
"""
Default constructor

Args:
path (str): Path to the hls4ml output folder, as specified via output_dir when converting the model
project_name (str, optional): hls4ml model name, if different than myproject
"""

self.path = path
self.project_name = project_name

# Set up dynamic C library
self.coyote_lib = ctypes.cdll.LoadLibrary(
f'{self.path}/build/{self.project_name}_cyt_sw/lib/libCoyoteInference.so'
)

self.coyote_lib.init_model_inference.argtypes = [ctypes.c_uint, ctypes.c_uint, ctypes.c_uint]
self.coyote_lib.init_model_inference.restype = ctypes.POINTER(ctypes.c_void_p)

self.coyote_lib.flush.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
self.coyote_lib.predict.argtypes = [ctypes.POINTER(ctypes.c_void_p)]

self.coyote_lib.get_inference_predictions.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_uint]
self.coyote_lib.get_inference_predictions.restype = ctypes.POINTER(ctypes.c_float)

self.coyote_lib.free_model_inference.argtypes = [ctypes.POINTER(ctypes.c_void_p)]

def program_hacc_fpga(self):
"""
Utility function for loading the Coyote-hls4ml bitstream and driver
on the ETH Zurich Heterogeneous Accelerated Compute Cluster (HACC).
On other clusters, users need to manually load the bitstream and driver;
guidance on this is provided in the Coyote documentation.
"""
os.system(
f'cd {self.path}/Coyote/driver && '
f'make && '
f'cd ../util && '
f'bash program_hacc_local.sh ../../build/{self.project_name}_cyt_hw/bitstreams/cyt_top.bit ../driver/build/coyote_driver.ko'
)

def predict(self, X: np.array, y_shape: tuple, batch_size: int = 1):
"""
Run model inference

Args:
X (np.array): Input data
y_shape (tuple): Shape of the output; used for allocating sufficient memory for the output
batch_size (int, optional): Inference batch size
"""
if len(X.shape) == 1:
X = np.array([X])
if X.dtype != np.float32:
logging.warning('CoyoteOverlay only supports (for now) floating-point inputs; casting input data to float32')
X = X.astype(np.float32)
y = np.empty((len(X), *y_shape))
np_pointer_nd = np.ctypeslib.ndpointer(dtype=np.float32, ndim=len(X[0].shape), flags='C')
self.coyote_lib.set_inference_data.argtypes = [ctypes.POINTER(ctypes.c_void_p), np_pointer_nd, ctypes.c_uint]

model = self.coyote_lib.init_model_inference(batch_size, int(np.prod(X[0].shape)), int(np.prod(y_shape)))

cnt = 0
avg_latency = 0
avg_throughput = 0
total_batches = 0
for x in X:
self.coyote_lib.set_inference_data(model, x, cnt)
cnt += 1
if cnt == batch_size:
self.coyote_lib.flush(model)

ts = time.time_ns()
self.coyote_lib.predict(model)
te = time.time_ns()

time_taken = te - ts
avg_latency += (time_taken / 1e3)
avg_throughput += (batch_size / (time_taken * 1e-9))

for j in range(batch_size):
tmp = self.coyote_lib.get_inference_predictions(model, j)
y[total_batches * batch_size + j] = np.ctypeslib.as_array(tmp, shape=y_shape)

cnt = 0
total_batches += 1

self.coyote_lib.free_model_inference(model)
print(f'Batch size: {batch_size}; batches processed: {total_batches}')
print(f'Mean latency: {round(avg_latency / total_batches, 3)}us (inference only)')
print(f'Mean throughput: {round(avg_throughput / total_batches, 1)} samples/s (inference only)')

return y
Empty file.
1 change: 1 addition & 0 deletions hls4ml/contrib/Coyote
Submodule Coyote added at 292ec1
44 changes: 44 additions & 0 deletions hls4ml/templates/coyote_accelerator/CMakeLists.txt
@@ -0,0 +1,44 @@
cmake_minimum_required(VERSION 3.5)
set(CYT_DIR ${CMAKE_SOURCE_DIR}/Coyote/)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CYT_DIR}/cmake)
find_package(CoyoteHW REQUIRED)
find_package(CoyoteSW REQUIRED)

set(FLOW "hw" CACHE STRING "Synthesize hardware (hw) or host software (sw)")

if(FLOW STREQUAL "hw")
project(myproject)
set(EN_STRM 1)
set(N_STRM_AXI 1)
set(N_REGIONS 1)

validation_checks_hw()
load_apps (
VFPGA_C0_0 "src"
)
create_hw()
endif()

if(FLOW STREQUAL "sw")
project(
CoyoteInference
VERSION 1.0.0
DESCRIPTION "CoyoteInference library"
)
set(CYT_INCLUDE_PATH ${CYT_DIR}/sw/include)
add_library(CoyoteInference SHARED "${CMAKE_SOURCE_DIR}/src/host_libs.cpp" "${CMAKE_SOURCE_DIR}/src/host_libs.hpp")
target_include_directories(CoyoteInference PUBLIC ${CYT_INCLUDE_PATH})
target_link_libraries(CoyoteInference PUBLIC Coyote)
target_link_directories(CoyoteInference PUBLIC /usr/local/lib)

project(myproject)
set(EXEC test)
set(TARGET_DIR "${CMAKE_SOURCE_DIR}/src/")
add_executable(${EXEC} ${TARGET_DIR}/myproject_host.cpp)
target_link_libraries(${EXEC} PUBLIC Coyote)
target_link_libraries(${EXEC} PUBLIC CoyoteInference)
target_link_directories(${EXEC} PUBLIC /usr/local/lib)
target_include_directories(${EXEC} PUBLIC src/hls/model_wrapper/firmware/)
target_include_directories(${EXEC} PUBLIC src/hls/model_wrapper/firmware/ap_types)

endif()