@@ -102,7 +102,7 @@ def decorator(function):
102
102
# ----------------------------------------------------------------
103
103
104
104
105
- def pytorch_to_hls (config ):
105
+ def parse_pytorch_model (config , verbose = True ):
106
106
"""Convert PyTorch model to hls4ml ModelGraph.
107
107
108
108
Args:
@@ -118,14 +118,15 @@ def pytorch_to_hls(config):
118
118
# This is a list of dictionaries to hold all the layer info we need to generate HLS
119
119
layer_list = []
120
120
121
- print ( 'Interpreting Model ...' )
122
-
121
+ if verbose :
122
+ print ( 'Interpreting Model ...' )
123
123
reader = PyTorchFileReader (config ) if isinstance (config ['PytorchModel' ], str ) else PyTorchModelReader (config )
124
124
if type (reader .input_shape ) is tuple :
125
125
input_shapes = [list (reader .input_shape )]
126
126
else :
127
127
input_shapes = list (reader .input_shape )
128
- input_shapes = [list (shape ) for shape in input_shapes ]
128
+ # first element needs to 'None' as placeholder for the batch size, insert it if not present
129
+ input_shapes = [[None ] + list (shape ) if shape [0 ] is not None else list (shape ) for shape in input_shapes ]
129
130
130
131
model = reader .torch_model
131
132
@@ -151,7 +152,8 @@ def pytorch_to_hls(config):
151
152
output_shape = None
152
153
153
154
# Loop through layers
154
- print ('Topology:' )
155
+ if verbose :
156
+ print ('Topology:' )
155
157
layer_counter = 0
156
158
157
159
n_inputs = 0
@@ -226,13 +228,14 @@ def pytorch_to_hls(config):
226
228
pytorch_class , layer_name , input_names , input_shapes , node , class_object , reader , config
227
229
)
228
230
229
- print (
230
- 'Layer name: {}, layer type: {}, input shape: {}' .format (
231
- layer ['name' ],
232
- layer ['class_name' ],
233
- input_shapes ,
231
+ if verbose :
232
+ print (
233
+ 'Layer name: {}, layer type: {}, input shape: {}' .format (
234
+ layer ['name' ],
235
+ layer ['class_name' ],
236
+ input_shapes ,
237
+ )
234
238
)
235
- )
236
239
layer_list .append (layer )
237
240
238
241
assert output_shape is not None
@@ -288,7 +291,12 @@ def pytorch_to_hls(config):
288
291
operation , layer_name , input_names , input_shapes , node , None , reader , config
289
292
)
290
293
291
- print ('Layer name: {}, layer type: {}, input shape: {}' .format (layer ['name' ], layer ['class_name' ], input_shapes ))
294
+ if verbose :
295
+ print (
296
+ 'Layer name: {}, layer type: {}, input shape: {}' .format (
297
+ layer ['name' ], layer ['class_name' ], input_shapes
298
+ )
299
+ )
292
300
layer_list .append (layer )
293
301
294
302
assert output_shape is not None
@@ -342,7 +350,12 @@ def pytorch_to_hls(config):
342
350
operation , layer_name , input_names , input_shapes , node , None , reader , config
343
351
)
344
352
345
- print ('Layer name: {}, layer type: {}, input shape: {}' .format (layer ['name' ], layer ['class_name' ], input_shapes ))
353
+ if verbose :
354
+ print (
355
+ 'Layer name: {}, layer type: {}, input shape: {}' .format (
356
+ layer ['name' ], layer ['class_name' ], input_shapes
357
+ )
358
+ )
346
359
layer_list .append (layer )
347
360
348
361
assert output_shape is not None
@@ -351,6 +364,11 @@ def pytorch_to_hls(config):
351
364
if len (input_layers ) == 0 :
352
365
input_layers = None
353
366
367
+ return layer_list , input_layers
368
+
369
+
370
+ def pytorch_to_hls (config ):
371
+ layer_list , input_layers = parse_pytorch_model (config )
354
372
print ('Creating HLS model' )
355
373
hls_model = ModelGraph (config , layer_list , inputs = input_layers )
356
374
return hls_model
0 commit comments