Skip to content

Commit 6bc2ac8

Browse files
committed
add shape group support
1 parent 2f27ca9 commit 6bc2ac8

File tree

3 files changed

+41
-24
lines changed

3 files changed

+41
-24
lines changed

axengine/_axclrt.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,8 @@ def run(
330330
self,
331331
output_names: list[str],
332332
input_feed: dict[str, np.ndarray],
333-
run_options=None
333+
run_options=None,
334+
shape_group: int = 0
334335
):
335336
self._validate_input(input_feed)
336337
self._validate_output(output_names)
@@ -340,13 +341,16 @@ def run(
340341
raise RuntimeError("axclrtSetCurrentContext failed")
341342

342343
if None is output_names:
343-
output_names = [o.name for o in self.get_outputs()]
344+
output_names = [o.name for o in self.get_outputs(shape_group)]
345+
346+
if (shape_group > self._shape_count - 1) or (shape_group < 0):
347+
raise ValueError(f"Invalid shape group: {shape_group}")
344348

345349
# fill model io
346350
dev_prt = axclrt_cffi.new("void **")
347351
dev_size = axclrt_cffi.new("uint64_t *")
348352
for key, npy in input_feed.items():
349-
for i, one in enumerate(self.get_inputs()):
353+
for i, one in enumerate(self.get_inputs(shape_group)):
350354
if one.name == key:
351355
assert (
352356
list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
@@ -363,21 +367,23 @@ def run(
363367
raise RuntimeError(f"axclrtMemcpy failed for input {i}.")
364368

365369
# execute model
366-
ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io[0])
370+
ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], shape_group, self._io[0])
367371

368372
# get output
369373
outputs = []
370374
if 0 == ret:
371-
for i in range(len(self.get_outputs())):
375+
for i in range(len(self.get_outputs(shape_group))):
372376
ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size)
373377
if 0 != ret:
374378
raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.")
375-
npy = np.zeros(self.get_outputs()[i].shape, dtype=self.get_outputs()[i].dtype)
376-
npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data)
377-
ret = axclrt_lib.axclrtMemcpy(npy_ptr, dev_prt[0], npy.nbytes, axclrt_lib.AXCL_MEMCPY_DEVICE_TO_HOST)
378-
if 0 != ret:
379-
raise RuntimeError(f"axclrtMemcpy failed for output {i}.")
380-
name = self.get_outputs()[i].name
379+
npy_size = self.get_outputs(shape_group)[i].dtype.itemsize * np.prod(self.get_outputs(shape_group)[i].shape)
380+
npy = np.frombuffer(
381+
axclrt_cffi.buffer(
382+
self._io[0].pOutputs[i].pVirAddr, npy_size
383+
),
384+
dtype=self.get_outputs(shape_group)[i].dtype,
385+
).reshape(self.get_outputs(shape_group)[i].shape)
386+
name = self.get_outputs(shape_group)[i].name
381387
if name in output_names:
382388
outputs.append(npy)
383389
return outputs

axengine/_axe.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -346,17 +346,21 @@ def run(
346346
self,
347347
output_names: list[str],
348348
input_feed: dict[str, np.ndarray],
349-
run_options=None
349+
run_options=None,
350+
shape_group: int = 0
350351
):
351352
self._validate_input(input_feed)
352353
self._validate_output(output_names)
353354

354355
if None is output_names:
355-
output_names = [o.name for o in self.get_outputs()]
356+
output_names = [o.name for o in self.get_outputs(shape_group)]
357+
358+
if (shape_group > self._shape_count - 1) or (shape_group < 0):
359+
raise ValueError(f"Invalid shape group: {shape_group}")
356360

357361
# fill model io
358362
for key, npy in input_feed.items():
359-
for i, one in enumerate(self.get_inputs()):
363+
for i, one in enumerate(self.get_inputs(shape_group)):
360364
if one.name == key:
361365
assert (
362366
list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
@@ -377,26 +381,32 @@ def run(
377381
break
378382

379383
# execute model
380-
ret = engine_lib.AX_ENGINE_RunSyncV2(
381-
self._handle[0], self._context[0], self._io
382-
)
384+
if self._shape_count > 1:
385+
ret = engine_lib.AX_ENGINE_RunGroupIOSync(
386+
self._handle[0], self._context[0], shape_group, self._io
387+
)
388+
else:
389+
ret = engine_lib.AX_ENGINE_RunSyncV2(
390+
self._handle[0], self._context[0], self._io
391+
)
383392

384393
# flush output
385394
outputs = []
386395
if 0 == ret:
387-
for i in range(len(self.get_outputs())):
396+
for i in range(len(self.get_outputs(shape_group))):
388397
sys_lib.AX_SYS_MinvalidateCache(
389398
self._io[0].pOutputs[i].phyAddr,
390399
self._io[0].pOutputs[i].pVirAddr,
391400
self._io[0].pOutputs[i].nSize,
392401
)
402+
npy_size = self.get_outputs(shape_group)[i].dtype.itemsize * np.prod(self.get_outputs(shape_group)[i].shape)
393403
npy = np.frombuffer(
394404
engine_cffi.buffer(
395-
self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize
405+
self._io[0].pOutputs[i].pVirAddr, npy_size
396406
),
397-
dtype=self.get_outputs()[i].dtype,
398-
).reshape(self.get_outputs()[i].shape)
399-
name = self.get_outputs()[i].name
407+
dtype=self.get_outputs(shape_group)[i].dtype,
408+
).reshape(self.get_outputs(shape_group)[i].shape)
409+
name = self.get_outputs(shape_group)[i].name
400410
if name in output_names:
401411
outputs.append(npy)
402412
return outputs

axengine/_session.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def run(
112112
self,
113113
output_names: list[str] | None,
114114
input_feed: dict[str, np.ndarray],
115-
run_options=None
115+
run_options=None,
116+
shape_group: int = 0
116117
) -> list[np.ndarray]:
117-
return self._sess.run(output_names, input_feed, run_options)
118+
return self._sess.run(output_names, input_feed, run_options, shape_group)

0 commit comments

Comments
 (0)