Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions hls4ml/backends/fpga/fpga_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ def __init__(self, name):
Dense,
Conv1D,
Conv2D,
SeparableConv1D,
SeparableConv2D,
Pooling1D,
Pooling2D,
GlobalPooling1D,
Expand All @@ -79,6 +77,16 @@ def __init__(self, name):
attrs.append(ConfigurableAttribute('reuse_factor', default=1))
self.attribute_map[layer] = attrs

# seperable is kind of special because it is effectively two layers that will be split
for layer in (SeparableConv1D, SeparableConv2D):
attrs = self.attribute_map.get(layer, [])
attrs.append(TypeAttribute('depthwise_accum'))
attrs.append(TypeAttribute('pointwise_accum'))
attrs.append(TypeAttribute('depthwise_result'))
attrs.append(ConfigurableAttribute('depthwise_reuse_factor', default=1))
attrs.append(ConfigurableAttribute('pointwise_reuse_factor', default=1))
self.attribute_map[layer] = attrs

act_attrs = self.attribute_map.get(Activation, [])
act_attrs.append(ConfigurableAttribute('table_size', default=1024))
act_attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8)))
Expand Down
1 change: 1 addition & 0 deletions hls4ml/backends/fpga/passes/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
class GenerateConvIm2col(OptimizerPass):
'''Generates tcode for im2col step of 1D/2d convolution'''

# Note, DepthwizeConv1D/2D also matches because it inherits from Conv1D/2D
def match(self, node):
return (
isinstance(node, (Conv1D, Conv2D, SeparableConv1D, SeparableConv2D))
Expand Down
2 changes: 1 addition & 1 deletion hls4ml/backends/vivado/passes/convolution_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def format(self, node):
# Override bias and bias_t since these are zeros in depthwise step of SepConv1D
params['bias'] = params['zero_bias']
params['bias_t'] = params['zero_bias_t']
params['n_filt'] = params['n_chan'] # In depthwise step n_chan == n_filt
params['n_filt'] = params['n_chan'] * node.get_attr('depth_multiplier') # In depthwise step n_chan == n_filt
params['dilation'] = node.get_attr('dilation', 1)
params['nzeros'] = node.get_weights('depthwise').nzeros
params['index'] = str(node.index) + '_depthwise'
Expand Down
32 changes: 26 additions & 6 deletions hls4ml/backends/vivado/vivado_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Conv1D,
Conv2D,
Dense,
DepthwiseConv1D,
DepthwiseConv2D,
Embedding,
GarNet,
Expand Down Expand Up @@ -74,12 +75,6 @@ def _register_layer_attributes(self):
attrs.append(ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded'], default='LineBuffer'))
self.attribute_map[layer] = attrs

sep_conv_layers = [SeparableConv1D, SeparableConv2D]
for layer in sep_conv_layers:
attrs = self.attribute_map.get(layer, [])
attrs.append(TypeAttribute('dw_output', default=FixedPrecisionType(18, 8)))
self.attribute_map[layer] = attrs

def _register_flows(self):
initializers = self._get_layer_initializers()
init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)
Expand Down Expand Up @@ -359,6 +354,31 @@ def init_sepconv1d(self, layer):
dw_output_t = NamedType(dw_out_name, dw_out_precision)
layer.set_attr('dw_output_t', dw_output_t)

@layer_optimizer(DepthwiseConv1D)
def init_depconv1d(self, layer):
if layer.model.config.is_resource_strategy(layer):
layer.set_attr('strategy', 'resource')
n_in, n_out = self.get_layer_mult_size(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
else:
layer.set_attr('strategy', 'latency')

out_width = layer.get_output_variable().shape[0]
chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1)
valid_pf = self.get_valid_conv_partition_splits(1, out_width)
if chosen_pf not in valid_pf:
closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf)
valid_pf_str = ','.join(map(str, valid_pf))
print(
f'WARNING: Invalid ParallelizationFactor={chosen_pf} in layer "{layer.name}".'
f'Using ParallelizationFactor={closest_pf} instead. Valid ParallelizationFactor(s): {valid_pf_str}.'
)
else:
closest_pf = chosen_pf
layer.set_attr('n_partitions', out_width // closest_pf)

layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower())

@layer_optimizer(Conv2D)
def init_conv2d(self, layer):
if len(layer.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D
Expand Down
10 changes: 8 additions & 2 deletions hls4ml/converters/keras/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ def parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader):

layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias')

if 'depth_multiplier' in keras_layer['config']:
layer['depth_multiplier'] = keras_layer['config']['depth_multiplier']

if 'filters' in keras_layer['config']:
layer['n_filt'] = keras_layer['config']['filters']
else:
layer['n_filt'] = layer['n_chan']
layer['n_filt'] = layer['n_chan'] * layer.get('depth_multiplier')
layer['filt_width'] = keras_layer['config']['kernel_size'][0]
layer['stride_width'] = keras_layer['config']['strides'][0]
layer['padding'] = keras_layer['config']['padding']
Expand Down Expand Up @@ -60,10 +63,13 @@ def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader):

layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias')

if 'depth_multiplier' in keras_layer['config']:
layer['depth_multiplier'] = keras_layer['config']['depth_multiplier']

if 'filters' in keras_layer['config']:
layer['n_filt'] = keras_layer['config']['filters']
else:
layer['n_filt'] = layer['n_chan']
layer['n_filt'] = layer['n_chan'] * layer.get('depth_multiplier')
layer['filt_height'] = keras_layer['config']['kernel_size'][0]
layer['filt_width'] = keras_layer['config']['kernel_size'][1]
layer['stride_height'] = keras_layer['config']['strides'][0]
Expand Down
100 changes: 74 additions & 26 deletions hls4ml/model/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ def get_layer_config(self, layer):

return layer_config

def set_name_config(self, name, config):
"""sets hls_config["LayerName"][name] = config"""
hls_config = self.config['HLSConfig']
layer_config = hls_config.setdefault('LayerName', {})
layer_config[name] = config

def get_precision(self, layer, var='default'):
precision = self.layer_name_precision.get(layer.name.lower() + '_' + var)
type_name = layer.name.lower() + '_' + var + '_t'
Expand Down Expand Up @@ -192,6 +198,35 @@ def get_compression(self, layer):

return compression

def parse_name_config(self, layer_name, layer_cfg):
"""This is used by _parse_hls_config below, but also in optimizers when a new layer config is created"""
precision_cfg = layer_cfg.get('Precision')
if isinstance(precision_cfg, dict):
for var, precision in precision_cfg.items():
self.layer_name_precision[layer_name.lower() + '_' + var] = precision
else:
self.layer_name_precision[layer_name.lower() + '_default'] = precision_cfg

rf = layer_cfg.get('ReuseFactor')
if rf is not None:
self.layer_name_rf[layer_name.lower()] = rf

targ_cycles = layer_cfg.get('TargetCycles')
if targ_cycles is not None:
self.layer_name_targ_cycles[layer_name.lower()] = targ_cycles

strategy = layer_cfg.get('Strategy')
if strategy is not None:
self.layer_name_strategy[layer_name.lower()] = strategy

conv_implementation = layer_cfg.get('ConvImplementation')
if conv_implementation is not None:
self.layer_name_conv_implementation[layer_name.lower()] = conv_implementation

compression = layer_cfg.get('Compression')
if compression is not None:
self.layer_name_compression[layer_name.lower()] = bool(compression)

def get_writer_config(self):
return self.writer_config

Expand Down Expand Up @@ -267,32 +302,7 @@ def _parse_hls_config(self):
layer_name_cfg = hls_config.get('LayerName')
if layer_name_cfg is not None:
for layer_name, layer_cfg in layer_name_cfg.items():
precision_cfg = layer_cfg.get('Precision')
if isinstance(precision_cfg, dict):
for var, precision in precision_cfg.items():
self.layer_name_precision[layer_name.lower() + '_' + var] = precision
else:
self.layer_name_precision[layer_name.lower() + '_default'] = precision_cfg

rf = layer_cfg.get('ReuseFactor')
if rf is not None:
self.layer_name_rf[layer_name.lower()] = rf

targ_cycles = layer_cfg.get('TargetCycles')
if targ_cycles is not None:
self.layer_name_targ_cycles[layer_name.lower()] = targ_cycles

strategy = layer_cfg.get('Strategy')
if strategy is not None:
self.layer_name_strategy[layer_name.lower()] = strategy

conv_implementation = layer_cfg.get('ConvImplementation')
if conv_implementation is not None:
self.layer_name_conv_implementation[layer_name.lower()] = conv_implementation

compression = layer_cfg.get('Compression')
if compression is not None:
self.layer_name_compression[layer_name.lower()] = bool(compression)
self.parse_name_config(layer_name, layer_cfg)

def _validate_hls_config(self):
use_dataflow = False
Expand Down Expand Up @@ -617,6 +627,44 @@ def replace_node(self, old_node, new_node):
self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items())
self._update_model_outputs()

def split_node(self, old_node, new_node1, new_node2):
"""Replace an existing node in the graph with two nodes in sequence.

Args:
old_node (Layer): The node to replace
new_node1 (Layer): The first new node in sequence
new_node2 (Layer): The second new node in sequence

"""

# fmt: off
assert len(new_node1.inputs) == len(old_node.inputs), \
f'{new_node1.name} and {old_node.name} have different number of inputs'
assert len(new_node2.outputs) == len(old_node.outputs), \
f'{new_node2.name} and {old_node.name} have different number of outputs'
# fmt: on

repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node2.outputs)}
repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node1.inputs)})

for node in self.graph.values():
for i, n in enumerate(node.inputs):
if n in repl:
node.inputs[i] = repl[n]
for i, n in enumerate(node.outputs):
if n in repl:
node.outputs[i] = repl[n]

new_graph = OrderedDict()
for key, value in self.graph.items():
if key == old_node.name:
new_graph[new_node1.name] = new_node1
new_graph[new_node2.name] = new_node2
else:
new_graph[key] = value
self.graph = new_graph
self._update_model_outputs()

def _update_model_outputs(self):
'''Update the model outputs

Expand Down
71 changes: 67 additions & 4 deletions hls4ml/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ def _set_accum_t(self):
accum_t = NamedType(*reversed(self.model.config.get_precision(self, 'accum')))
self.set_attr('accum_t', accum_t)

def _set_type_t(self, name):
has_type_t = any(a for a in self.expected_attributes if a.name == name + '_t' and isinstance(a, TypeAttribute))
if has_type_t:
type_t = NamedType(*reversed(self.model.config.get_precision(self, name)))
self.set_attr(name + '_t', type_t)

def get_input_node(self, input_name=None):
if input_name is None:
if len(self.inputs) > 0:
Expand Down Expand Up @@ -446,6 +452,7 @@ class SeparableConv1D(Layer):
Attribute('out_width'),
Attribute('n_chan'),
Attribute('n_filt'),
Attribute('depth_multiplier', default=1),
Attribute('filt_width'),
Attribute('stride_width'),
Attribute('pad_left'),
Expand Down Expand Up @@ -476,14 +483,35 @@ def initialize(self):

self.add_bias(quantizer=self.get_attr('bias_quantizer'))

# set the needed types if needed
self._set_type_t('pointwise_accum')
self._set_type_t('depthwise_accum')
self._set_type_t('depthwise_result')


class DepthwiseConv1D(Conv1D):
_expected_attributes = [
Attribute('in_width'),
Attribute('out_width'),
Attribute('n_chan'),
Attribute('depth_multiplier', default=1),
Attribute('n_filt'), # = n_chan * depth_multiplier
Attribute('filt_width'),
Attribute('stride_width'),
Attribute('pad_left'),
Attribute('pad_right'),
WeightAttribute('weight'),
WeightAttribute('bias'),
TypeAttribute('weight'),
TypeAttribute('bias'),
]

def initialize(self):
if self.get_attr('data_format') == 'channels_last':
shape = [self.attributes['out_width'], self.attributes['n_chan']]
shape = [self.attributes['out_width'], self.attributes['n_filt']]
dims = [f'OUT_HEIGHT_{self.index}', f'N_CHAN_{self.index}']
else:
shape = [self.attributes['n_chan'], self.attributes['out_width']]
shape = [self.attributes['n_filt'], self.attributes['out_width']]
dims = [f'N_CHAN_{self.index}', f'OUT_WIDTH_{self.index}']
self.add_output_variable(shape, dims)

Expand Down Expand Up @@ -588,6 +616,7 @@ class SeparableConv2D(Layer):
Attribute('out_width'),
Attribute('n_chan'),
Attribute('n_filt'),
Attribute('depth_multiplier', default=1),
Attribute('filt_height'),
Attribute('filt_width'),
Attribute('stride_height'),
Expand Down Expand Up @@ -622,14 +651,48 @@ def initialize(self):

self.add_bias(quantizer=self.get_attr('bias_quantizer'))

self._set_type_t('pointwise_accum')
self._set_type_t('depthwise_accum')
self._set_type_t('depthwise_result')


class DepthwiseConv2D(Conv2D):
_expected_attributes = [
Attribute('in_height'),
Attribute('in_width'),
Attribute('out_height'),
Attribute('out_width'),
Attribute('n_chan'),
Attribute('depth_multiplier', default=1),
Attribute('n_filt'), # = n_chan * depth_multiplier
Attribute('filt_height'),
Attribute('filt_width'),
Attribute('stride_height'),
Attribute('stride_width'),
Attribute('pad_top'),
Attribute('pad_bottom'),
Attribute('pad_left'),
Attribute('pad_right'),
WeightAttribute('weight'),
WeightAttribute('bias'),
TypeAttribute('weight'),
TypeAttribute('bias'),
]

def initialize(self):
if self.get_attr('data_format') == 'channels_last':
shape = [self.attributes['out_height'], self.attributes['out_width'], self.attributes['n_chan']]
shape = [
self.attributes['out_height'],
self.attributes['out_width'],
self.attributes['n_filt'],
]
dims = [f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}', f'N_CHAN_{self.index}']
else:
shape = [self.attributes['n_chan'], self.attributes['out_height'], self.attributes['out_width']]
shape = [
self.attributes['n_filt'],
self.attributes['out_height'],
self.attributes['out_width'],
]
dims = [f'N_CHAN_{self.index}', f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}']
self.add_output_variable(shape, dims)

Expand Down
1 change: 1 addition & 0 deletions hls4ml/model/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
register_flow(
'convert',
[
'seperable_to_depthwise_and_conv', # has to be before precision inference
'infer_precision_types',
'channels_last_converter',
'remove_transpose_before_flatten',
Expand Down
Loading