Skip to content

Commit

Permalink
release tensorlist outside aclnn_concat
Browse files Browse the repository at this point in the history
  • Loading branch information
hipudding committed Apr 18, 2024
1 parent 83f5ada commit 6deb7b6
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions ggml-cann/aclnn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,9 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ACL_CHECK(aclDestroyTensor(acl_dst));
}

void aclnn_concat(ggml_backend_cann_context& ctx, aclTensor *acl_src0,
aclTensor *acl_src1, aclTensor *acl_dst, int64_t concat_dim,
void aclnn_concat(ggml_backend_cann_context& ctx, aclTensorList* tensorList,
aclTensor* acl_dst, int64_t concat_dim,
ggml_tensor* bind_tensor) {

aclTensor* tensors[] = {acl_src0, acl_src1};
aclTensorList* tensorList = aclCreateTensorList(tensors, 2);

uint64_t workspaceSize = 0;
aclOpExecutor* executor;
void* workspaceAddr = nullptr;
Expand All @@ -157,12 +153,6 @@ void aclnn_concat(ggml_backend_cann_context& ctx, aclTensor *acl_src0,

aclrtStream main_stream = ctx.stream();
ACL_CHECK(aclnnCat(workspaceAddr, workspaceSize, executor, main_stream));

//ACL_CHECK(aclDestroyTensor(acl_src0));
//ACL_CHECK(aclDestroyTensor(acl_src1));
ACL_CHECK(aclDestroyTensorList(tensorList));
ACL_CHECK(aclDestroyTensor(acl_dst));

}

void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
Expand All @@ -173,14 +163,11 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
aclTensor* acl_dst = create_acl_tensor(dst);

int64_t concat_dim = 1;
aclTensor* tensors[] = {acl_src0, acl_src1};
aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
aclnn_concat(ctx, tensorList, acl_dst, concat_dim, dst);

aclnn_concat(ctx, acl_src0, acl_src1, acl_dst, concat_dim, dst);

// release acl_src0, acl_src1 in aclnn_concat
// ACL_CHECK(aclDestroyTensor(acl_src0));
// ACL_CHECK(aclDestroyTensor(acl_src1));
// ->
// ACL_CHECK(aclDestroyTensorList(tensorList));
ACL_CHECK(aclDestroyTensorList(tensorList));
ACL_CHECK(aclDestroyTensor(acl_dst));
}

Expand Down Expand Up @@ -1331,7 +1318,9 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* d
// concat
int64_t concat_dim = 3;
aclTensor* acl_dst = create_acl_tensor(dst);
aclnn_concat(ctx, tmp_cos_tensor, tmp_sin_tensor, acl_dst, concat_dim, dst);
aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor};
aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
aclnn_concat(ctx, tensorList, acl_dst, concat_dim, dst);

// release
ACL_CHECK(aclDestroyTensor(acl_src));
Expand Down

0 comments on commit 6deb7b6

Please sign in to comment.