@@ -96,6 +96,8 @@ def __init__(
96
96
super ().__init__ ()
97
97
98
98
self ._device_index = 0
99
+ self ._io = None
100
+ self ._model_id = None
99
101
100
102
if provider_options is not None and "device_id" in provider_options [0 ]:
101
103
self ._device_index = provider_options [0 ].get ("device_id" , 0 )
@@ -214,12 +216,12 @@ def _unload(self):
214
216
dev_size = axclrt_cffi .new ("uint64_t *" )
215
217
dev_prt = axclrt_cffi .new ("void **" )
216
218
for i in range (axclrt_lib .axclrtEngineGetNumInputs (self ._info [0 ])):
217
- axclrt_lib .axclrtEngineGetInputBufferByIndex (self ._io , i , dev_prt , dev_size )
219
+ axclrt_lib .axclrtEngineGetInputBufferByIndex (self ._io [ 0 ] , i , dev_prt , dev_size )
218
220
axclrt_lib .axclrtFree (dev_prt [0 ])
219
221
for i in range (axclrt_lib .axclrtEngineGetNumOutputs (self ._info [0 ])):
220
- axclrt_lib .axclrtEngineGetOutputBufferByIndex (self ._io , i , dev_prt , dev_size )
222
+ axclrt_lib .axclrtEngineGetOutputBufferByIndex (self ._io [ 0 ] , i , dev_prt , dev_size )
221
223
axclrt_lib .axclrtFree (dev_prt [0 ])
222
- axclrt_lib .axclrtEngineDestroyIO (self ._io )
224
+ axclrt_lib .axclrtEngineDestroyIO (self ._io [ 0 ] )
223
225
self ._io = None
224
226
if self ._model_id [0 ] is not None and self ._model_id [0 ] != 0 :
225
227
axclrt_lib .axclrtEngineUnload (self ._model_id [0 ])
@@ -322,7 +324,7 @@ def _prepare_io(self):
322
324
ret = axclrt_lib .axclrtEngineSetOutputBufferByIndex (_io [0 ], i , dev_ptr [0 ], max_size )
323
325
if 0 != ret :
324
326
raise RuntimeError (f"axclrtEngineSetOutputBufferByIndex failed 0x{ ret :08x} for output { i } ." )
325
- return _io [ 0 ]
327
+ return _io
326
328
327
329
def run (
328
330
self ,
@@ -353,21 +355,21 @@ def run(
353
355
if not (npy .flags .c_contiguous or npy .flags .f_contiguous ):
354
356
npy = np .ascontiguousarray (npy )
355
357
npy_ptr = axclrt_cffi .cast ("void *" , npy .ctypes .data )
356
- ret = axclrt_lib .axclrtEngineGetInputBufferByIndex (self ._io , i , dev_prt , dev_size )
358
+ ret = axclrt_lib .axclrtEngineGetInputBufferByIndex (self ._io [ 0 ] , i , dev_prt , dev_size )
357
359
if 0 != ret :
358
360
raise RuntimeError (f"axclrtEngineGetInputBufferByIndex failed for input { i } ." )
359
361
ret = axclrt_lib .axclrtMemcpy (dev_prt [0 ], npy_ptr , npy .nbytes , axclrt_lib .AXCL_MEMCPY_HOST_TO_DEVICE )
360
362
if 0 != ret :
361
363
raise RuntimeError (f"axclrtMemcpy failed for input { i } ." )
362
364
363
365
# execute model
364
- ret = axclrt_lib .axclrtEngineExecute (self ._model_id [0 ], self ._context_id [0 ], 0 , self ._io )
366
+ ret = axclrt_lib .axclrtEngineExecute (self ._model_id [0 ], self ._context_id [0 ], 0 , self ._io [ 0 ] )
365
367
366
368
# get output
367
369
outputs = []
368
370
if 0 == ret :
369
371
for i in range (len (self .get_outputs ())):
370
- ret = axclrt_lib .axclrtEngineGetOutputBufferByIndex (self ._io , i , dev_prt , dev_size )
372
+ ret = axclrt_lib .axclrtEngineGetOutputBufferByIndex (self ._io [ 0 ] , i , dev_prt , dev_size )
371
373
if 0 != ret :
372
374
raise RuntimeError (f"axclrtEngineGetOutputBufferByIndex failed for output { i } ." )
373
375
npy = np .zeros (self .get_outputs ()[i ].shape , dtype = self .get_outputs ()[i ].dtype )
0 commit comments