@@ -330,7 +330,8 @@ def run(
330
330
self ,
331
331
output_names : list [str ],
332
332
input_feed : dict [str , np .ndarray ],
333
- run_options = None
333
+ run_options = None ,
334
+ shape_group : int = 0
334
335
):
335
336
self ._validate_input (input_feed )
336
337
self ._validate_output (output_names )
@@ -340,13 +341,16 @@ def run(
340
341
raise RuntimeError ("axclrtSetCurrentContext failed" )
341
342
342
343
if None is output_names :
343
- output_names = [o .name for o in self .get_outputs ()]
344
+ output_names = [o .name for o in self .get_outputs (shape_group )]
345
+
346
+ if (shape_group > self ._shape_count - 1 ) or (shape_group < 0 ):
347
+ raise ValueError (f"Invalid shape group: { shape_group } " )
344
348
345
349
# fill model io
346
350
dev_prt = axclrt_cffi .new ("void **" )
347
351
dev_size = axclrt_cffi .new ("uint64_t *" )
348
352
for key , npy in input_feed .items ():
349
- for i , one in enumerate (self .get_inputs ()):
353
+ for i , one in enumerate (self .get_inputs (shape_group )):
350
354
if one .name == key :
351
355
assert (
352
356
list (one .shape ) == list (npy .shape ) and one .dtype == npy .dtype
@@ -363,21 +367,23 @@ def run(
363
367
raise RuntimeError (f"axclrtMemcpy failed for input { i } ." )
364
368
365
369
# execute model
366
- ret = axclrt_lib .axclrtEngineExecute (self ._model_id [0 ], self ._context_id [0 ], 0 , self ._io [0 ])
370
+ ret = axclrt_lib .axclrtEngineExecute (self ._model_id [0 ], self ._context_id [0 ], shape_group , self ._io [0 ])
367
371
368
372
# get output
369
373
outputs = []
370
374
if 0 == ret :
371
- for i in range (len (self .get_outputs ())):
375
+ for i in range (len (self .get_outputs (shape_group ))):
372
376
ret = axclrt_lib .axclrtEngineGetOutputBufferByIndex (self ._io [0 ], i , dev_prt , dev_size )
373
377
if 0 != ret :
374
378
raise RuntimeError (f"axclrtEngineGetOutputBufferByIndex failed for output { i } ." )
375
- npy = np .zeros (self .get_outputs ()[i ].shape , dtype = self .get_outputs ()[i ].dtype )
376
- npy_ptr = axclrt_cffi .cast ("void *" , npy .ctypes .data )
377
- ret = axclrt_lib .axclrtMemcpy (npy_ptr , dev_prt [0 ], npy .nbytes , axclrt_lib .AXCL_MEMCPY_DEVICE_TO_HOST )
378
- if 0 != ret :
379
- raise RuntimeError (f"axclrtMemcpy failed for output { i } ." )
380
- name = self .get_outputs ()[i ].name
379
+ npy_size = self .get_outputs (shape_group )[i ].dtype .itemsize * np .prod (self .get_outputs (shape_group )[i ].shape )
380
+ npy = np .frombuffer (
381
+ axclrt_cffi .buffer (
382
+ self ._io [0 ].pOutputs [i ].pVirAddr , npy_size
383
+ ),
384
+ dtype = self .get_outputs (shape_group )[i ].dtype ,
385
+ ).reshape (self .get_outputs (shape_group )[i ].shape )
386
+ name = self .get_outputs (shape_group )[i ].name
381
387
if name in output_names :
382
388
outputs .append (npy )
383
389
return outputs
0 commit comments