@@ -209,14 +209,18 @@ def __init__(self, graph_id, device_id):
209
209
self .output_dtypes = []
210
210
self .output_datasize = []
211
211
for item in shapes :
212
+ if item == '' :
213
+ self .output_shapes .append ([])
214
+ continue
212
215
elems = item .split (',' )
213
216
elems = [int (x ) for x in elems ]
214
217
self .output_shapes .append (elems )
215
218
for item in dtypes :
216
219
elem = int (item )
217
220
self .output_dtypes .append (elem )
218
221
for i in range (len (shapes )):
219
- elem_size = math .prod (self .output_shapes [i ])
222
+ elem_size = math .prod (self .output_shapes [i ]) if len (
223
+ self .output_shapes [i ]) > 0 else 1
220
224
self .output_datasize .append (
221
225
elem_size * acl .data_type_size (self .output_dtypes [i ]))
222
226
self .output_datasize_c = (
@@ -242,14 +246,18 @@ def __init__(self, graph_id, device_id):
242
246
self .input_datasize = []
243
247
244
248
for item in shapes :
249
+ if item == '' :
250
+ self .input_shapes .append ([])
251
+ continue
245
252
elems = item .split (',' )
246
253
elems = [int (x ) for x in elems ]
247
254
self .input_shapes .append (elems )
248
255
for item in dtypes :
249
256
elem = int (item )
250
257
self .input_dtypes .append (elem )
251
258
for i in range (len (shapes )):
252
- elem_size = math .prod (self .input_shapes [i ])
259
+ elem_size = math .prod (self .input_shapes [i ]) if len (
260
+ self .input_shapes [i ]) > 0 else 1
253
261
self .input_datasize .append (
254
262
elem_size * acl .data_type_size (self .input_dtypes [i ]))
255
263
self .input_datasize_c = (
0 commit comments