diff --git a/ggml-cann/aclnn_ops.cpp b/ggml-cann/aclnn_ops.cpp index d6880f2250737f..4256ba779eb53a 100644 --- a/ggml-cann/aclnn_ops.cpp +++ b/ggml-cann/aclnn_ops.cpp @@ -137,13 +137,9 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ACL_CHECK(aclDestroyTensor(acl_dst)); } -void aclnn_concat(ggml_backend_cann_context& ctx, aclTensor *acl_src0, - aclTensor *acl_src1, aclTensor *acl_dst, int64_t concat_dim, +void aclnn_concat(ggml_backend_cann_context& ctx, aclTensorList* tensorList, + aclTensor* acl_dst, int64_t concat_dim, ggml_tensor* bind_tensor) { - - aclTensor* tensors[] = {acl_src0, acl_src1}; - aclTensorList* tensorList = aclCreateTensorList(tensors, 2); - uint64_t workspaceSize = 0; aclOpExecutor* executor; void* workspaceAddr = nullptr; @@ -157,12 +153,6 @@ void aclnn_concat(ggml_backend_cann_context& ctx, aclTensor *acl_src0, aclrtStream main_stream = ctx.stream(); ACL_CHECK(aclnnCat(workspaceAddr, workspaceSize, executor, main_stream)); - - //ACL_CHECK(aclDestroyTensor(acl_src0)); - //ACL_CHECK(aclDestroyTensor(acl_src1)); - ACL_CHECK(aclDestroyTensorList(tensorList)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - } void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -173,14 +163,11 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* acl_dst = create_acl_tensor(dst); int64_t concat_dim = 1; + aclTensor* tensors[] = {acl_src0, acl_src1}; + aclTensorList* tensorList = aclCreateTensorList(tensors, 2); + aclnn_concat(ctx, tensorList, acl_dst, concat_dim, dst); - aclnn_concat(ctx, acl_src0, acl_src1, acl_dst, concat_dim, dst); - - // release acl_src0, acl_src1 in aclnn_concat - // ACL_CHECK(aclDestroyTensor(acl_src0)); - // ACL_CHECK(aclDestroyTensor(acl_src1)); - // -> - // ACL_CHECK(aclDestroyTensorList(tensorList)); + ACL_CHECK(aclDestroyTensorList(tensorList)); ACL_CHECK(aclDestroyTensor(acl_dst)); } @@ -1331,7 +1318,9 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* d // concat int64_t concat_dim = 3; aclTensor* acl_dst = create_acl_tensor(dst); - aclnn_concat(ctx, tmp_cos_tensor, tmp_sin_tensor, acl_dst, concat_dim, dst); + aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor}; + aclTensorList* tensorList = aclCreateTensorList(tensors, 2); + aclnn_concat(ctx, tensorList, acl_dst, concat_dim, dst); // release ACL_CHECK(aclDestroyTensor(acl_src));