Skip to content

Commit 2cb6fe1

Browse files
Add functionality to use granularity option also for pytorch models (#1051)
* allow granularity options in pytorch parser * pre-commit * [pre-commit.ci] auto fixes from pre-commit hooks * add torch to setup? * add torch to setup2? * add torch to setup3? * add torch to requirements * fix failing pytest * adapat new batchnorm pytests to changes in interface * addressing comments from Vladimir and Jovan * remvoving torch from requirements --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2898ab2 commit 2cb6fe1

10 files changed

+204
-100
lines changed

hls4ml/converters/__init__.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from hls4ml.converters.keras_to_hls import get_supported_keras_layers # noqa: F401
1111
from hls4ml.converters.keras_to_hls import parse_keras_model # noqa: F401
1212
from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler
13+
14+
# from hls4ml.converters.pytorch_to_hls import parse_pytorch_model # noqa: F401
1315
from hls4ml.model import ModelGraph
1416
from hls4ml.utils.config import create_config
1517
from hls4ml.utils.symbolic_utils import LUTFunction
@@ -238,7 +240,6 @@ def convert_from_keras_model(
238240

239241
def convert_from_pytorch_model(
240242
model,
241-
input_shape,
242243
output_dir='my-hls-test',
243244
project_name='myproject',
244245
input_data_tb=None,
@@ -251,7 +252,6 @@ def convert_from_pytorch_model(
251252
252253
Args:
253254
model: PyTorch model to convert.
254-
input_shape (list): The shape of the input tensor. First element is the batch size, needs to be None
255255
output_dir (str, optional): Output directory of the generated HLS project. Defaults to 'my-hls-test'.
256256
project_name (str, optional): Name of the HLS project. Defaults to 'myproject'.
257257
input_data_tb (str, optional): String representing the path of input data in .npy or .dat format that will be
@@ -293,17 +293,16 @@ def convert_from_pytorch_model(
293293
config = create_config(output_dir=output_dir, project_name=project_name, backend=backend, **kwargs)
294294

295295
config['PytorchModel'] = model
296-
config['InputShape'] = input_shape
297296
config['InputData'] = input_data_tb
298297
config['OutputPredictions'] = output_data_tb
299298
config['HLSConfig'] = {}
300299

301300
if hls_config is None:
302301
hls_config = {}
303302

304-
model_config = hls_config.get('Model', None)
303+
model_config = hls_config.get('Model')
305304
config['HLSConfig']['Model'] = _check_model_config(model_config)
306-
305+
config['InputShape'] = hls_config.get('InputShape')
307306
_check_hls_config(config, hls_config)
308307

309308
return pytorch_to_hls(config)

hls4ml/converters/pytorch_to_hls.py

+31-13
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def decorator(function):
102102
# ----------------------------------------------------------------
103103

104104

105-
def pytorch_to_hls(config):
105+
def parse_pytorch_model(config, verbose=True):
106106
"""Convert PyTorch model to hls4ml ModelGraph.
107107
108108
Args:
@@ -118,14 +118,15 @@ def pytorch_to_hls(config):
118118
# This is a list of dictionaries to hold all the layer info we need to generate HLS
119119
layer_list = []
120120

121-
print('Interpreting Model ...')
122-
121+
if verbose:
122+
print('Interpreting Model ...')
123123
reader = PyTorchFileReader(config) if isinstance(config['PytorchModel'], str) else PyTorchModelReader(config)
124124
if type(reader.input_shape) is tuple:
125125
input_shapes = [list(reader.input_shape)]
126126
else:
127127
input_shapes = list(reader.input_shape)
128-
input_shapes = [list(shape) for shape in input_shapes]
128+
# first element needs to 'None' as placeholder for the batch size, insert it if not present
129+
input_shapes = [[None] + list(shape) if shape[0] is not None else list(shape) for shape in input_shapes]
129130

130131
model = reader.torch_model
131132

@@ -151,7 +152,8 @@ def pytorch_to_hls(config):
151152
output_shape = None
152153

153154
# Loop through layers
154-
print('Topology:')
155+
if verbose:
156+
print('Topology:')
155157
layer_counter = 0
156158

157159
n_inputs = 0
@@ -226,13 +228,14 @@ def pytorch_to_hls(config):
226228
pytorch_class, layer_name, input_names, input_shapes, node, class_object, reader, config
227229
)
228230

229-
print(
230-
'Layer name: {}, layer type: {}, input shape: {}'.format(
231-
layer['name'],
232-
layer['class_name'],
233-
input_shapes,
231+
if verbose:
232+
print(
233+
'Layer name: {}, layer type: {}, input shape: {}'.format(
234+
layer['name'],
235+
layer['class_name'],
236+
input_shapes,
237+
)
234238
)
235-
)
236239
layer_list.append(layer)
237240

238241
assert output_shape is not None
@@ -288,7 +291,12 @@ def pytorch_to_hls(config):
288291
operation, layer_name, input_names, input_shapes, node, None, reader, config
289292
)
290293

291-
print('Layer name: {}, layer type: {}, input shape: {}'.format(layer['name'], layer['class_name'], input_shapes))
294+
if verbose:
295+
print(
296+
'Layer name: {}, layer type: {}, input shape: {}'.format(
297+
layer['name'], layer['class_name'], input_shapes
298+
)
299+
)
292300
layer_list.append(layer)
293301

294302
assert output_shape is not None
@@ -342,7 +350,12 @@ def pytorch_to_hls(config):
342350
operation, layer_name, input_names, input_shapes, node, None, reader, config
343351
)
344352

345-
print('Layer name: {}, layer type: {}, input shape: {}'.format(layer['name'], layer['class_name'], input_shapes))
353+
if verbose:
354+
print(
355+
'Layer name: {}, layer type: {}, input shape: {}'.format(
356+
layer['name'], layer['class_name'], input_shapes
357+
)
358+
)
346359
layer_list.append(layer)
347360

348361
assert output_shape is not None
@@ -351,6 +364,11 @@ def pytorch_to_hls(config):
351364
if len(input_layers) == 0:
352365
input_layers = None
353366

367+
return layer_list, input_layers
368+
369+
370+
def pytorch_to_hls(config):
371+
layer_list, input_layers = parse_pytorch_model(config)
354372
print('Creating HLS model')
355373
hls_model = ModelGraph(config, layer_list, inputs=input_layers)
356374
return hls_model

hls4ml/utils/config.py

+79
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def make_layer_config(layer):
269269

270270
def config_from_pytorch_model(
271271
model,
272+
input_shape,
272273
granularity='model',
273274
backend=None,
274275
default_precision='ap_fixed<16,6>',
@@ -284,6 +285,7 @@ def config_from_pytorch_model(
284285
285286
Args:
286287
model: PyTorch model
288+
input_shape (tuple or list of tuples): The shape of the input tensor, excluding the batch size.
287289
granularity (str, optional): Granularity of the created config. Defaults to 'model'.
288290
Can be set to 'model', 'type' and 'layer'.
289291
@@ -321,6 +323,83 @@ def config_from_pytorch_model(
321323
model_config['Strategy'] = 'Latency'
322324

323325
config['Model'] = model_config
326+
config['PytorchModel'] = model
327+
if not (isinstance(input_shape, tuple) or (isinstance(input_shape, list) and isinstance(input_shape[0], tuple))):
328+
raise Exception('Input shape must be tuple (single input) or list of tuples (multiple inputs)')
329+
config['InputShape'] = input_shape
330+
331+
if granularity.lower() not in ['model', 'type', 'name']:
332+
raise Exception(
333+
f'Invalid configuration granularity specified, expected "model", "type" or "name" got "{granularity}"'
334+
)
335+
336+
if backend is not None:
337+
backend = hls4ml.backends.get_backend(backend)
338+
339+
from hls4ml.converters.pytorch_to_hls import parse_pytorch_model
340+
341+
(
342+
layer_list,
343+
_,
344+
) = parse_pytorch_model(config, verbose=False)
345+
346+
def make_layer_config(layer):
347+
cls_name = layer['class_name']
348+
if 'config' in layer.keys():
349+
if 'activation' in layer['config'].keys():
350+
if layer['config']['activation'] == 'softmax':
351+
cls_name = 'Softmax'
352+
353+
layer_cls = hls4ml.model.layers.layer_map[cls_name]
354+
if backend is not None:
355+
layer_cls = backend.create_layer_class(layer_cls)
356+
357+
layer_config = {}
358+
359+
config_attrs = [a for a in layer_cls.expected_attributes if a.configurable]
360+
for attr in config_attrs:
361+
if isinstance(attr, hls4ml.model.attributes.TypeAttribute):
362+
precision_cfg = layer_config.setdefault('Precision', {})
363+
name = attr.name
364+
if name.endswith('_t'):
365+
name = name[:-2]
366+
if attr.default is None:
367+
precision_cfg[name] = default_precision
368+
else:
369+
precision_cfg[name] = str(attr.default)
370+
elif attr.name == 'reuse_factor':
371+
layer_config[attr.config_name] = default_reuse_factor
372+
else:
373+
if attr.default is not None:
374+
layer_config[attr.config_name] = attr.default
375+
376+
if layer['class_name'] == 'Input':
377+
dtype = layer['config']['dtype']
378+
if dtype.startswith('int') or dtype.startswith('uint'):
379+
typename = dtype[: dtype.index('int') + 3]
380+
width = int(dtype[dtype.index('int') + 3 :])
381+
layer_config['Precision']['result'] = f'ap_{typename}<{width}>'
382+
# elif bool, q[u]int, ...
383+
384+
return layer_config
385+
386+
if granularity.lower() == 'type':
387+
type_config = {}
388+
for layer in layer_list:
389+
if layer['class_name'] in type_config:
390+
continue
391+
layer_config = make_layer_config(layer)
392+
type_config[layer['class_name']] = layer_config
393+
394+
config['LayerType'] = type_config
395+
396+
elif granularity.lower() == 'name':
397+
name_config = {}
398+
for layer in layer_list:
399+
layer_config = make_layer_config(layer)
400+
name_config[layer['name']] = layer_config
401+
402+
config['LayerName'] = name_config
324403

325404
return config
326405

test/pytest/test_backend_config.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_backend_config(framework, backend, part, clock_period, clock_unc):
3131
convert_fn = hls4ml.converters.convert_from_keras_model
3232
else:
3333
model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.ReLU())
34-
config = hls4ml.utils.config_from_pytorch_model(model)
34+
config = hls4ml.utils.config_from_pytorch_model(model, input_shape=(None, 1))
3535
convert_fn = hls4ml.converters.convert_from_pytorch_model
3636

3737
if clock_unc is not None:
@@ -42,16 +42,27 @@ def test_backend_config(framework, backend, part, clock_period, clock_unc):
4242
test_dir = f'hls4mlprj_backend_config_{framework}_{backend}_part_{part}_period_{clock_period}_unc_{unc_str}'
4343
output_dir = test_root_path / test_dir
4444

45-
hls_model = convert_fn(
46-
model,
47-
input_shape=(None, 1), # This serves as a test of handling unexpected values by the backend in keras converer
48-
hls_config=config,
49-
output_dir=str(output_dir),
50-
backend=backend,
51-
part=part,
52-
clock_period=clock_period,
53-
clock_uncertainty=clock_unc,
54-
)
45+
if framework == "keras":
46+
hls_model = convert_fn(
47+
model,
48+
input_shape=(None, 1), # This serves as a test of handling unexpected values by the backend in keras converer
49+
hls_config=config,
50+
output_dir=str(output_dir),
51+
backend=backend,
52+
part=part,
53+
clock_period=clock_period,
54+
clock_uncertainty=clock_unc,
55+
)
56+
else:
57+
hls_model = convert_fn(
58+
model,
59+
hls_config=config,
60+
output_dir=str(output_dir),
61+
backend=backend,
62+
part=part,
63+
clock_period=clock_period,
64+
clock_uncertainty=clock_unc,
65+
)
5566

5667
hls_model.write()
5768

test/pytest/test_batchnorm_pytorch.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,12 @@ def test_batchnorm(data, backend, io_type):
3939

4040
default_precision = 'ac_fixed<32, 1, true>' if backend == 'Quartus' else 'ac_fixed<32, 1>'
4141

42-
config = hls4ml.utils.config_from_pytorch_model(model, default_precision=default_precision, granularity='name')
42+
config = hls4ml.utils.config_from_pytorch_model(
43+
model, (in_shape,), default_precision=default_precision, granularity='name'
44+
)
4345
output_dir = str(test_root_path / f'hls4mlprj_batchnorm_{backend}_{io_type}')
4446
hls_model = hls4ml.converters.convert_from_pytorch_model(
45-
model, (None, in_shape), backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir
47+
model, backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir
4648
)
4749
hls_model.compile()
4850

@@ -94,17 +96,20 @@ def test_batchnorm_fusion(fusion_data, backend, io_type):
9496
# We do not have an implementation of a transpose for io_stream, need to transpose inputs and outputs outside of hls4ml
9597
if io_type == 'io_stream':
9698
fusion_data = np.ascontiguousarray(fusion_data.transpose(0, 2, 1))
97-
config = hls4ml.utils.config_from_pytorch_model(model, channels_last_conversion='internal', transpose_outputs=False)
99+
config = hls4ml.utils.config_from_pytorch_model(
100+
model, (n_in, size_in_height), channels_last_conversion='internal', transpose_outputs=False
101+
)
98102
else:
99-
config = hls4ml.utils.config_from_pytorch_model(model, channels_last_conversion='full', transpose_outputs=True)
103+
config = hls4ml.utils.config_from_pytorch_model(
104+
model, (n_in, size_in_height), channels_last_conversion='full', transpose_outputs=True
105+
)
100106

101107
config['Model']['Strategy'] = 'Resource'
102108

103109
# conversion
104110
output_dir = str(test_root_path / f'hls4mlprj_block_{backend}_{io_type}')
105111
hls_model = hls4ml.converters.convert_from_pytorch_model(
106112
model,
107-
(None, n_in, size_in_height),
108113
hls_config=config,
109114
output_dir=output_dir,
110115
backend=backend,

test/pytest/test_merge_pytorch.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,16 @@ def test_merge(merge_op, io_type, backend):
4141
model = MergeModule(merge_op)
4242
model.eval()
4343

44-
batch_input_shape = (None,) + input_shape
4544
config = hls4ml.utils.config_from_pytorch_model(
46-
model, default_precision='ap_fixed<32,16>', channels_last_conversion="internal", transpose_outputs=False
45+
model,
46+
[input_shape, input_shape],
47+
default_precision='ap_fixed<32,16>',
48+
channels_last_conversion="internal",
49+
transpose_outputs=False,
4750
)
4851
output_dir = str(test_root_path / f'hls4mlprj_merge_pytorch_{merge_op}_{backend}_{io_type}')
4952
hls_model = hls4ml.converters.convert_from_pytorch_model(
5053
model,
51-
[batch_input_shape, batch_input_shape],
5254
hls_config=config,
5355
output_dir=output_dir,
5456
io_type=io_type,

0 commit comments

Comments
 (0)