Skip to content

Support for floating point types #1307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion hls4ml/backends/fpga/fpga_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@
from hls4ml.model.types import (
ExponentPrecisionType,
FixedPrecisionType,
FloatPrecisionType,
IntegerPrecisionType,
PrecisionType,
RoundingMode,
SaturationMode,
StandardFloatPrecisionType,
UnspecifiedPrecisionType,
XnorPrecisionType,
)
Expand Down Expand Up @@ -343,11 +345,22 @@ def convert_precision_string(cls, precision):
if precision.lower() == 'auto':
return cls._convert_auto_type(precision)

if precision in ['float', 'double', 'half', 'bfloat16'] or precision.startswith(
('ap_float', 'ac_std_float', 'std_float')
):
return cls._convert_standard_float_type(precision)

if precision.startswith('ac_float'):
return cls._convert_ac_float_type(precision)

if precision.startswith('ac_'):
return cls._convert_ac_type(precision)
else:

if precision.startswith(('ap_', 'fixed', 'ufixed', 'int', 'uint')): # We parse AP notation even without 'ap_' prefix
return cls._convert_ap_type(precision)

raise ValueError(f'Unsupported precision type: {precision}')

@classmethod
def _convert_ap_type(cls, precision):
'''
Expand Down Expand Up @@ -416,6 +429,44 @@ def _convert_ac_type(cls, precision):
elif 'int' in precision:
return IntegerPrecisionType(width, signed)

@classmethod
def _convert_standard_float_type(cls, precision):
# Some default values
if precision == 'float':
return StandardFloatPrecisionType(width=32, exponent=8, use_cpp_type=True)
if precision == 'double':
return StandardFloatPrecisionType(width=64, exponent=11, use_cpp_type=True)
if precision == 'half':
return StandardFloatPrecisionType(width=16, exponent=5, use_cpp_type=True)
if precision == 'bfloat16':
return StandardFloatPrecisionType(width=16, exponent=8, use_cpp_type=True)

# If it is a float type, parse the width and exponent
bits = re.search('.+<(.+?)>', precision).group(1).split(',')
if len(bits) == 2:
width = int(bits[0])
exponent = int(bits[1])
return StandardFloatPrecisionType(width=width, exponent=exponent, use_cpp_type=False)
else:
raise ValueError(f'Invalid standard float precision format: {precision}')

@classmethod
def _convert_ac_float_type(cls, precision):
# If it is a float type, parse the width and exponent
bits = re.search('.+<(.+?)>', precision).group(1).split(',')
if len(bits) == 3 or len(bits) == 4:
mantissa = int(bits[0])
integer = int(bits[1])
exponent = int(bits[2])
width = mantissa + exponent
if len(bits) == 4:
round_mode = RoundingMode.from_string(bits[3])
else:
round_mode = None
return FloatPrecisionType(width=width, integer=integer, exponent=exponent, rounding_mode=round_mode)
else:
raise ValueError(f'Invalid ac_float precision format: {precision}')

@classmethod
def _convert_auto_type(cls, precision):
'''
Expand Down
55 changes: 52 additions & 3 deletions hls4ml/backends/fpga/fpga_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
ExponentPrecisionType,
ExponentType,
FixedPrecisionType,
FloatPrecisionType,
IntegerPrecisionType,
NamedType,
PackedType,
StandardFloatPrecisionType,
XnorPrecisionType,
)

Expand Down Expand Up @@ -51,6 +53,21 @@ def definition_cpp(self):
return typestring


class APFloatPrecisionDefinition(PrecisionDefinition):
def definition_cpp(self):
raise NotImplementedError(
'FloatPrecisionType is not supported in AP type precision definitions. Use StandardFloatPrecisionType instead.'
)


class APStandardFloatPrecisionDefinition(PrecisionDefinition):
def definition_cpp(self):
typestring = str(self)
if typestring.startswith('std_float'):
typestring = typestring.replace('std_float', 'ap_float')
return typestring


class ACIntegerPrecisionDefinition(PrecisionDefinition):
def definition_cpp(self):
typestring = f'ac_int<{self.width}, {str(self.signed).lower()}>'
Expand Down Expand Up @@ -90,12 +107,40 @@ def definition_cpp(self):
return typestring


class ACFloatPrecisionDefinition(PrecisionDefinition):
def _rounding_mode_cpp(self, mode):
if mode is not None:
return 'AC_' + str(mode)

def definition_cpp(self):
args = [
self.width,
self.integer,
self.exponent,
self._rounding_mode_cpp(self.rounding_mode),
]
if args[3] == 'AC_TRN':
# This is the default, so we won't write the full definition for brevity
args[3] = None
args = ','.join([str(arg) for arg in args[:5] if arg is not None])
typestring = f'ac_float<{args}>'
return typestring


class ACStandardFloatPrecisionDefinition(PrecisionDefinition):
def definition_cpp(self):
typestring = str(self)
if typestring.startswith('std_float'):
typestring = 'ac_' + typestring
return typestring


class PrecisionConverter:
def convert(self, precision_type):
raise NotImplementedError


class FixedPrecisionConverter(PrecisionConverter):
class FPGAPrecisionConverter(PrecisionConverter):
def __init__(self, type_map, prefix):
self.type_map = type_map
self.prefix = prefix
Expand All @@ -120,25 +165,29 @@ def convert(self, precision_type):
raise Exception(f'Cannot convert precision type to {self.prefix}: {precision_type.__class__.__name__}')


class APTypeConverter(FixedPrecisionConverter):
class APTypeConverter(FPGAPrecisionConverter):
def __init__(self):
super().__init__(
type_map={
FixedPrecisionType: APFixedPrecisionDefinition,
IntegerPrecisionType: APIntegerPrecisionDefinition,
FloatPrecisionType: APFloatPrecisionDefinition,
StandardFloatPrecisionType: APStandardFloatPrecisionDefinition,
ExponentPrecisionType: APIntegerPrecisionDefinition,
XnorPrecisionType: APIntegerPrecisionDefinition,
},
prefix='AP',
)


class ACTypeConverter(FixedPrecisionConverter):
class ACTypeConverter(FPGAPrecisionConverter):
def __init__(self):
super().__init__(
type_map={
FixedPrecisionType: ACFixedPrecisionDefinition,
IntegerPrecisionType: ACIntegerPrecisionDefinition,
FloatPrecisionType: ACFloatPrecisionDefinition,
StandardFloatPrecisionType: ACStandardFloatPrecisionDefinition,
ExponentPrecisionType: ACIntegerPrecisionDefinition,
XnorPrecisionType: ACIntegerPrecisionDefinition,
},
Expand Down
10 changes: 8 additions & 2 deletions hls4ml/backends/oneapi/oneapi_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@

from hls4ml.backends.fpga.fpga_types import (
ACFixedPrecisionDefinition,
ACFloatPrecisionDefinition,
ACIntegerPrecisionDefinition,
FixedPrecisionConverter,
ACStandardFloatPrecisionDefinition,
FloatPrecisionType,
FPGAPrecisionConverter,
HLSTypeConverter,
NamedTypeConverter,
PrecisionDefinition,
StandardFloatPrecisionType,
TypeDefinition,
TypePrecisionConverter,
VariableDefinition,
Expand All @@ -35,12 +39,14 @@ def definition_cpp(self):
return typestring


class OneAPIACTypeConverter(FixedPrecisionConverter):
class OneAPIACTypeConverter(FPGAPrecisionConverter):
def __init__(self):
super().__init__(
type_map={
FixedPrecisionType: ACFixedPrecisionDefinition,
IntegerPrecisionType: ACIntegerPrecisionDefinition,
FloatPrecisionType: ACFloatPrecisionDefinition,
StandardFloatPrecisionType: ACStandardFloatPrecisionDefinition,
ExponentPrecisionType: ACExponentPrecisionDefinition,
XnorPrecisionType: ACIntegerPrecisionDefinition,
},
Expand Down
123 changes: 122 additions & 1 deletion hls4ml/model/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,126 @@ def __str__(self):
return typestring


class FloatPrecisionType(PrecisionType):
"""
Class representing a floating-point precision type.

This type is equivalent to ac_float HLS types. If the use of C++ equivalent types is required, see
``StandardFloatPrecisionType``.

Args:
width (int, optional): Total number of bits used. Defaults to 33.
integer (int, optional): Number of bits used for the integer part. Defaults to 2.
exponent (int, optional): Number of bits used for the exponent. Defaults to 8.
"""

def __init__(self, width=33, integer=2, exponent=8, rounding_mode=None):
super().__init__(width=width, signed=True)
self.exponent = exponent
self.integer = integer # If None, will be set to width - exponent - 1
self.rounding_mode = rounding_mode

@property
def rounding_mode(self):
return self._rounding_mode

@rounding_mode.setter
def rounding_mode(self, mode):
if mode is None:
self._rounding_mode = RoundingMode.TRN
elif isinstance(mode, str):
self._rounding_mode = RoundingMode.from_string(mode)
else:
self._rounding_mode = mode

def __str__(self):
args = [self.width - self.exponent, self.integer, self.exponent, self.rounding_mode]
args = ','.join([str(arg) for arg in args])
typestring = f'float<{args}>'
return typestring

def __eq__(self, other: object) -> bool:
if isinstance(other, FloatPrecisionType):
eq = super().__eq__(other)
eq = eq and self.integer == other.integer
eq = eq and self.exponent == other.exponent
eq = eq and self.rounding_mode == other.rounding_mode
return eq

return False

def __hash__(self) -> int:
return super().__hash__() ^ hash((self.integer, self.exponent, self.rounding_mode))

def serialize_state(self):
state = super().serialize_state()
state.update(
{
'integer': self.integer,
'exponent': self.exponent,
'rounding_mode': str(self.rounding_mode),
}
)
return state


class StandardFloatPrecisionType(PrecisionType):
"""
Class representing a floating-point precision type.

This type is equivalent to ap_float and ac_std_float HLS types. <32,8> corresponds to a 'float' type in C/C++. <64,11>
corresponds to a 'double' type in C/C++. <16,5> corresponds to a 'half' type in C/C++. <16,8> corresponds to a
'bfloat16' type in C/C++.

Args:
width (int, optional): Total number of bits used. Defaults to 32.
exponent (int, optional): Number of bits used for the exponent. Defaults to 8.
use_cpp_type (bool, optional): Use C++ equivalent types if available. Defaults to ``True``.
"""

def __init__(self, width=32, exponent=8, use_cpp_type=True):
super().__init__(width=width, signed=True)
self.exponent = exponent
self.use_cpp_type = use_cpp_type

def __str__(self):
if self._check_cpp_type(32, 8):
typestring = 'float'
elif self._check_cpp_type(64, 11):
typestring = 'double'
elif self._check_cpp_type(16, 5):
typestring = 'half'
elif self._check_cpp_type(16, 8):
typestring = 'bfloat16'
else:
typestring = f'std_float<{self.width},{self.exponent}>'
return typestring

def _check_cpp_type(self, width, exponent):
return self.use_cpp_type and self.width == width and self.exponent == exponent

def __eq__(self, other: object) -> bool:
if isinstance(other, FloatPrecisionType):
eq = super().__eq__(other)
eq = eq and self.exponent == other.exponent
return eq

return False

def __hash__(self) -> int:
return super().__hash__() ^ hash(self.exponent)

def serialize_state(self):
state = super().serialize_state()
state.update(
{
'exponent': self.exponent,
'use_cpp_type': self.use_cpp_type,
}
)
return state


class UnspecifiedPrecisionType(PrecisionType):
"""
Class representing an unspecified precision type.
Expand Down Expand Up @@ -592,7 +712,8 @@ def update_precision(self, new_precision):
elif isinstance(new_precision, FixedPrecisionType):
decimal_spaces = max(0, new_precision.fractional)
self.precision_fmt = f'{{:.{decimal_spaces}f}}'

elif isinstance(new_precision, (FloatPrecisionType, StandardFloatPrecisionType)):
self.precision_fmt = '{:.16f}' # Not ideal, but should be enough for most cases
else:
raise RuntimeError(f"Unexpected new precision type: {new_precision}")

Expand Down
1 change: 1 addition & 0 deletions hls4ml/templates/oneapi/firmware/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define DEFINES_H_

#include <sycl/ext/intel/ac_types/ac_fixed.hpp>
#include <sycl/ext/intel/ac_types/ac_float.hpp>
#include <sycl/ext/intel/ac_types/ac_int.hpp>
#include <sycl/ext/intel/fpga_extensions.hpp>
#include <sycl/sycl.hpp>
Expand Down
2 changes: 2 additions & 0 deletions hls4ml/templates/quartus/firmware/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef __INTELFPGA_COMPILER__

#include "ac_fixed.h"
#include "ac_float.h"
#include "ac_int.h"
#define hls_register

Expand All @@ -24,6 +25,7 @@ template <typename T> using stream_out = nnet::stream<T>;
#else

#include "HLS/ac_fixed.h"
#include "HLS/ac_float.h"
#include "HLS/ac_int.h"
#include "HLS/hls.h"

Expand Down
Loading
Loading