Skip to content

Commit bcc8b8d

Browse files
wanglushengkalcohol
wanglusheng
authored andcommitted
fix cffi life cycle
1 parent e28b968 commit bcc8b8d

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

axengine/_axclrt.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def __init__(
9696
super().__init__()
9797

9898
self._device_index = 0
99+
self._io = None
100+
self._model_id = None
99101

100102
if provider_options is not None and "device_id" in provider_options[0]:
101103
self._device_index = provider_options[0].get("device_id", 0)
@@ -214,12 +216,12 @@ def _unload(self):
214216
dev_size = axclrt_cffi.new("uint64_t *")
215217
dev_prt = axclrt_cffi.new("void **")
216218
for i in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])):
217-
axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io, i, dev_prt, dev_size)
219+
axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io[0], i, dev_prt, dev_size)
218220
axclrt_lib.axclrtFree(dev_prt[0])
219221
for i in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])):
220-
axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io, i, dev_prt, dev_size)
222+
axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size)
221223
axclrt_lib.axclrtFree(dev_prt[0])
222-
axclrt_lib.axclrtEngineDestroyIO(self._io)
224+
axclrt_lib.axclrtEngineDestroyIO(self._io[0])
223225
self._io = None
224226
if self._model_id[0] is not None and self._model_id[0] != 0:
225227
axclrt_lib.axclrtEngineUnload(self._model_id[0])
@@ -322,7 +324,7 @@ def _prepare_io(self):
322324
ret = axclrt_lib.axclrtEngineSetOutputBufferByIndex(_io[0], i, dev_ptr[0], max_size)
323325
if 0 != ret:
324326
raise RuntimeError(f"axclrtEngineSetOutputBufferByIndex failed 0x{ret:08x} for output {i}.")
325-
return _io[0]
327+
return _io
326328

327329
def run(
328330
self,
@@ -353,21 +355,21 @@ def run(
353355
if not (npy.flags.c_contiguous or npy.flags.f_contiguous):
354356
npy = np.ascontiguousarray(npy)
355357
npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data)
356-
ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io, i, dev_prt, dev_size)
358+
ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io[0], i, dev_prt, dev_size)
357359
if 0 != ret:
358360
raise RuntimeError(f"axclrtEngineGetInputBufferByIndex failed for input {i}.")
359361
ret = axclrt_lib.axclrtMemcpy(dev_prt[0], npy_ptr, npy.nbytes, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE)
360362
if 0 != ret:
361363
raise RuntimeError(f"axclrtMemcpy failed for input {i}.")
362364

363365
# execute model
364-
ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io)
366+
ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io[0])
365367

366368
# get output
367369
outputs = []
368370
if 0 == ret:
369371
for i in range(len(self.get_outputs())):
370-
ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io, i, dev_prt, dev_size)
372+
ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size)
371373
if 0 != ret:
372374
raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.")
373375
npy = np.zeros(self.get_outputs()[i].shape, dtype=self.get_outputs()[i].dtype)

0 commit comments

Comments
 (0)