Skip to content

Commit 51cb83c

Browse files
authored
Merge pull request #1112 from JanFSchulte/pytorch_auto
Make auto default precision for pytorch parser
2 parents ef2e8f4 + 6aeafdd commit 51cb83c

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

hls4ml/utils/config.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def config_from_pytorch_model(
284284
default_reuse_factor=1,
285285
channels_last_conversion='full',
286286
transpose_outputs=True,
287+
max_precision=None,
287288
):
288289
"""Create an HLS conversion config given the PyTorch model.
289290
@@ -304,7 +305,8 @@ def config_from_pytorch_model(
304305
will generate config keys for every layer separately, allowing for highly specific
305306
configuration tweaks.
306307
backend(str, optional): Name of the backend to use
307-
default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'.
308+
default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'. Note, this must
309+
be an explicit precision: 'auto' is not allowed.
308310
default_reuse_factor (int, optional): Default reuse factor. Defaults to 1.
309311
channels_last_conversion (string, optional): Configures the conversion of pytorch layers to
310312
'channels_last' dataformate. Can be set to 'full', 'internal', or 'off'. If 'full', both the inputs
@@ -313,6 +315,8 @@ def config_from_pytorch_model(
313315
transpose_outputs (bool, optional): Set to 'False' if the output should not be transposed from
314316
channels_last into channels_first data format. Defaults to 'False'. If False, outputs needs
315317
to be transposed manually.
318+
max_precision (str or None, optional): Maximum width precision to use. Defaults to None, meaning no maximum.
319+
Note: Only integer and fixed precisions are supported
316320
317321
Raises:
318322
Exception: If PyTorch model has layers not supported by hls4ml.
@@ -324,11 +328,16 @@ def config_from_pytorch_model(
324328
config = {}
325329

326330
model_config = {}
327-
model_config['Precision'] = default_precision
331+
model_config['Precision'] = {}
332+
model_config['Precision']['default'] = default_precision
333+
if max_precision is not None:
334+
model_config['Precision']['maximum'] = max_precision
328335
model_config['ReuseFactor'] = default_reuse_factor
329336
model_config['ChannelsLastConversion'] = channels_last_conversion
330337
model_config['TransposeOutputs'] = transpose_outputs
331338
model_config['Strategy'] = 'Latency'
339+
model_config['BramFactor'] = 1_000_000_000
340+
model_config['TraceOutput'] = False
332341

333342
config['Model'] = model_config
334343
config['PytorchModel'] = model
@@ -372,7 +381,7 @@ def make_layer_config(layer):
372381
if name.endswith('_t'):
373382
name = name[:-2]
374383
if attr.default is None:
375-
precision_cfg[name] = default_precision
384+
precision_cfg[name] = 'auto'
376385
else:
377386
precision_cfg[name] = str(attr.default)
378387
elif attr.name == 'reuse_factor':

0 commit comments

Comments
 (0)