diff --git a/ggml-cann.cpp b/ggml-cann.cpp
index 36685d2bb21818..956ddeb6f6fc58 100644
--- a/ggml-cann.cpp
+++ b/ggml-cann.cpp
@@ -450,6 +450,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
             break;
         case GGML_OP_ROPE:
         case GGML_OP_ALIBI:
+            ggml_cann_alibi(ctx, dst);
+            break;
         case GGML_OP_IM2COL:
             ggml_cann_im2col(ctx, dst);
             break;
@@ -685,8 +687,9 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
             return true;
        case GGML_OP_SOFT_MAX:
        case GGML_OP_ROPE:
-       case GGML_OP_ALIBI:
            return false;
+       case GGML_OP_ALIBI:
+           return true;
        case GGML_OP_IM2COL:
            return true;
        case GGML_OP_POOL_2D:
diff --git a/ggml-cann/aclnn_ops.cpp b/ggml-cann/aclnn_ops.cpp
index a2d536524c652c..b547a89f9d3e2f 100644
--- a/ggml-cann/aclnn_ops.cpp
+++ b/ggml-cann/aclnn_ops.cpp
@@ -6,10 +6,12 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -60,6 +62,30 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
 
+void aclnn_add(ggml_backend_cann_context& ctx, aclTensor *acl_src0,
+               aclTensor *acl_src1, aclTensor *acl_dst,
+               ggml_tensor* bind_tensor) {
+
+    aclScalar* alpha = nullptr;
+    float alphaValue = 1.0f;
+    alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst,
+                                       &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize);
+    }
+
+    aclrtStream main_stream = ctx.stream();
+    ACL_CHECK(aclnnAdd(workspaceAddr, workspaceSize, executor, main_stream));
+
+    ACL_CHECK(aclDestroyScalar(alpha));
+}
+
 void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];
     ggml_tensor* src1 = dst->src[1];
@@ -81,24 +107,8 @@ void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         acl_dst = create_acl_tensor(dst);
     }
 
-    aclScalar* alpha = nullptr;
-    float alphaValue = 1.0f;
-    alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
-
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
+    aclnn_add(ctx, acl_src0, acl_src1, acl_dst, dst);
 
-    ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst,
-                                       &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        workspaceAddr = ctx.alloc_buffer(dst, workspaceSize);
-    }
-
-    aclrtStream main_stream = ctx.stream();
-    ACL_CHECK(aclnnAdd(workspaceAddr, workspaceSize, executor, main_stream));
-
-    ACL_CHECK(aclDestroyScalar(alpha));
     ACL_CHECK(aclDestroyTensor(acl_src0));
     ACL_CHECK(aclDestroyTensor(acl_src1));
     ACL_CHECK(aclDestroyTensor(acl_dst));
@@ -1158,7 +1168,8 @@ void aclnn_inplace_mul(ggml_backend_cann_context& ctx, aclTensor *acl_src,
 }
 
 void aclnn_noinplcace_mul(ggml_backend_cann_context& ctx, aclTensor *acl_src,
-                          aclTensor *acl_other, aclTensor *acl_dst, ggml_tensor* bind_tensor) {
+                          aclTensor *acl_other, aclTensor *acl_dst,
+                          ggml_tensor* bind_tensor) {
 
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
@@ -1208,7 +1219,8 @@ void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor *acl_src,
                        ctx.stream()));
 }
 
-void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
+                                  ggml_tensor* dst) {
     const ggml_tensor* src = dst->src[0];
 
     GGML_ASSERT(src->type == GGML_TYPE_F32);
@@ -1253,7 +1265,7 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* d
         tmp_permute_nb[i] = tmp_permute_nb[i-1] * tmp_permute_ne[i-1];
     }
 
-    void* tmp_permute_buffer = ctx.alloc_buffer(dst, ggml_nbytes(src)*320);
+    void* tmp_permute_buffer = ctx.alloc_buffer(dst, ggml_nbytes(src));
     aclTensor* tmp_permute_tenosr = create_acl_tensor(tmp_permute_buffer,
                                                       type_mapping(src->type),
                                                       ggml_type_size(src->type),
@@ -1323,12 +1335,252 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* d
     aclnn_concat(ctx, tensorList, acl_dst, concat_dim, dst);
 
     // release
+    // Destroying both the tensorList and its elements causes a segmentation fault.
     ACL_CHECK(aclDestroyTensorList(tensorList));
     ACL_CHECK(aclDestroyTensor(acl_src));
     ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
     ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr));
     ACL_CHECK(aclDestroyTensor(tmp_mul_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_cos_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_sin_tensor));
     ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
+                       aclTensor *acl_dst, ggml_tensor* bind_tensor) {
+
+    auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnInplaceFillScalarGetWorkspaceSize(acl_dst, acl_scalar,
+                                                     &workspaceSize,
+                                                     &executor));
+    if (workspaceSize > 0) {
+        workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize);
+    }
+
+    ACL_CHECK(aclnnInplaceFillScalar(workspaceAddr, workspaceSize, executor,
+                                     ctx.stream()));
+    ACL_CHECK(aclDestroyScalar(acl_scalar));
+}
+
+void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor *acl_dst,
+                             aclTensor *acl_exp, ggml_tensor* bind_tensor) {
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnInplacePowTensorTensorGetWorkspaceSize(acl_dst, acl_exp,
+                                                          &workspaceSize,
+                                                          &executor));
+    if (workspaceSize > 0) {
+        workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize);
+    }
+
+    ACL_CHECK(aclnnInplacePowTensorTensor(workspaceAddr, workspaceSize,
+                                          executor, ctx.stream()));
+}
+
+void ggml_cann_alibi(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    const int n_head = ((int32_t*) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t*) dst->op_params + 2, sizeof(float));
+
+    const int64_t ne0 = src->ne[0]; // all_seq_len = n_past + ne1
+    const int64_t ne1 = src->ne[1]; // seq_len_without_past
+    const int64_t ne2 = src->ne[2]; // n_head -> this is k
+    const int64_t ne3 = src->ne[3]; // batch
+
+    const int64_t n = ggml_nrows(src);
+    const int64_t ne2_ne3 = n/ne1; // ne2*ne3
+
+    const size_t nb0 = src->nb[0];
+
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(n_head == ne2);
+
+    // add alibi to src (KQ_scaled)
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+    // init arange
+    void* tmp_arange_buffer = ctx.alloc_buffer(dst, ne2_ne3 *
+                                                    ggml_type_size(dst->type));
+    size_t memset_size = ne2_ne3 * ggml_type_size(dst->type);
+    ACL_CHECK(aclrtMemset(tmp_arange_buffer, memset_size, 0, memset_size));
+
+    // arange1: [1, ..., n_heads_log2_floor+1)
+    float start = 1;
+    float stop = n_heads_log2_floor + 1;
+    float step = 1;
+    int64_t n_elements_arange = n_heads_log2_floor;
+
+    int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
+    size_t tmp_arange1_nb[] = {sizeof(dst->type)};
+    aclTensor* tmp_arange1_tensor = create_acl_tensor(tmp_arange_buffer,
+                                                      type_mapping(dst->type),
+                                                      ggml_type_size(dst->type),
+                                                      tmp_arange1_ne,
+                                                      tmp_arange1_nb,
+                                                      GGML_MAX_DIMS-3,
+                                                      ACL_FORMAT_ND);
+
+    aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange,
+                 dst);
+
+    aclTensor* tmp_arange2_tensor = nullptr;
+    if (n_heads_log2_floor < ne2_ne3) {
+        // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
+        start = 1;
+        stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
+        step = 2;
+        n_elements_arange = ne2_ne3 - n_heads_log2_floor;
+        int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+        size_t tmp_arange2_nb[] = {sizeof(dst->type)};
+
+        tmp_arange2_tensor = create_acl_tensor(tmp_arange_buffer +
+                                               n_heads_log2_floor *
+                                               ggml_type_size(dst->type),
+                                               type_mapping(dst->type),
+                                               ggml_type_size(dst->type),
+                                               tmp_arange2_ne,
+                                               tmp_arange2_nb,
+                                               GGML_MAX_DIMS-3,
+                                               ACL_FORMAT_ND);
+        aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
+                     n_elements_arange, dst);
+    }
+
+    // init mk_base
+    void* tmp_mk_base_buffer = ctx.alloc_buffer(dst, ne2_ne3 *
+                                                     ggml_type_size(dst->type));
+    int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
+    size_t tmp_mk_base1_nb[] = {sizeof(dst->type)};
+    aclTensor* tmp_mk_base1_tensor = create_acl_tensor(tmp_mk_base_buffer,
+                                                       type_mapping(dst->type),
+                                                       ggml_type_size(dst->type),
+                                                       tmp_mk_base1_ne,
+                                                       tmp_mk_base1_nb,
+                                                       GGML_MAX_DIMS-3,
+                                                       ACL_FORMAT_ND);
+
+    aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor, dst);
+
+    aclTensor* tmp_mk_base2_tensor = nullptr;
+    if (n_heads_log2_floor < ne2_ne3) {
+        int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+        size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
+        tmp_mk_base2_tensor = create_acl_tensor(tmp_mk_base_buffer +
+                                                n_heads_log2_floor *
+                                                ggml_type_size(dst->type),
+                                                type_mapping(dst->type),
+                                                ggml_type_size(dst->type),
+                                                tmp_mk_base2_ne,
+                                                tmp_mk_base2_nb,
+                                                GGML_MAX_DIMS-3,
+                                                ACL_FORMAT_ND);
+        aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor, dst);
+    }
+
+    // init mk
+    int64_t tmp_mk_base_ne[] = {ne2_ne3};
+    size_t tmp_mk_base_nb[] = {sizeof(dst->type)};
+    aclTensor* tmp_mk_base_tensor = create_acl_tensor(tmp_mk_base_buffer,
+                                                      type_mapping(dst->type),
+                                                      ggml_type_size(dst->type),
+                                                      tmp_mk_base_ne,
+                                                      tmp_mk_base_nb,
+                                                      GGML_MAX_DIMS-3,
+                                                      ACL_FORMAT_ND);
+    aclTensor* tmp_arange_tensor = create_acl_tensor(tmp_arange_buffer,
+                                                     type_mapping(dst->type),
+                                                     ggml_type_size(dst->type),
+                                                     tmp_mk_base_ne,
+                                                     tmp_mk_base_nb,
+                                                     GGML_MAX_DIMS-3,
+                                                     ACL_FORMAT_ND);
+    aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor, dst);
+
+    // reshape mk
+    int64_t tmp_mk_ne[] = {1, 1, ne2, ne3};
+    size_t tmp_mk_nb[GGML_MAX_DIMS];
+    tmp_mk_nb[0] = ggml_type_size(dst->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        tmp_mk_nb[i] = tmp_mk_nb[i-1] * tmp_mk_ne[i-1];
+    }
+    aclTensor* tmp_mk_tensor = create_acl_tensor(tmp_mk_base_buffer,
+                                                 type_mapping(dst->type),
+                                                 ggml_type_size(dst->type),
+                                                 tmp_mk_ne,
+                                                 tmp_mk_nb,
+                                                 GGML_MAX_DIMS,
+                                                 ACL_FORMAT_ND);
+
+    // arange: [0, ..., ne0)
+    start = 0;
+    stop = ne0;
+    step = 1;
+    n_elements_arange = ne0;
+    int64_t tmp_arange3_ne[] = {ne0, 1, 1, 1};
+    size_t tmp_arange3_nb[] = {sizeof(dst->type)};
+
+    void* tmp_arange3_buffer = ctx.alloc_buffer(dst, ne0 * sizeof(dst->type));
+    aclTensor* tmp_arange3_tensor = create_acl_tensor(tmp_arange3_buffer,
+                                                      type_mapping(dst->type),
+                                                      ggml_type_size(dst->type),
+                                                      tmp_arange3_ne,
+                                                      tmp_arange3_nb,
+                                                      GGML_MAX_DIMS,
+                                                      ACL_FORMAT_ND);
+
+    aclnn_arange(ctx, tmp_arange3_tensor, start, stop, step, n_elements_arange,
+                 dst);
+
+    // arange3 * mk
+    int64_t tmp_output_ne[] = {ne0, 1, ne2, ne3};
+    size_t tmp_output_nb[GGML_MAX_DIMS];
+    tmp_output_nb[0] = ggml_type_size(dst->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        tmp_output_nb[i] = tmp_output_nb[i-1] * tmp_output_ne[i-1];
+    }
+    void* tmp_output_buffer = ctx.alloc_buffer(dst, ggml_nbytes(dst));
+    aclTensor* tmp_output_tensor = create_acl_tensor(tmp_output_buffer,
+                                                     type_mapping(dst->type),
+                                                     ggml_type_size(dst->type),
+                                                     tmp_output_ne,
+                                                     tmp_output_nb,
+                                                     GGML_MAX_DIMS,
+                                                     ACL_FORMAT_ND);
+    aclnn_noinplcace_mul(ctx, tmp_arange3_tensor, tmp_mk_tensor,
+                         tmp_output_tensor, dst);
+
+    // add
+    aclTensor* acl_src = create_acl_tensor(src);
+    aclTensor* acl_dst = create_acl_tensor(dst);
+    aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst, dst);
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ACL_CHECK(aclDestroyTensor(tmp_arange1_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_arange2_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_mk_base1_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_mk_base2_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_mk_base_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_mk_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_arange3_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_output_tensor));
+}
\ No newline at end of file
diff --git a/ggml-cann/aclnn_ops.h b/ggml-cann/aclnn_ops.h
index b1ae0da35604a5..7b632eb90f2595 100644
--- a/ggml-cann/aclnn_ops.h
+++ b/ggml-cann/aclnn_ops.h
@@ -64,6 +64,8 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
                                   ggml_tensor* dst);
 
+void ggml_cann_alibi(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 template
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 51b3487b2a948c..b47dd84dada386 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1473,6 +1473,31 @@ struct test_leaky_relu : public test_case {
     }
 };
 
+// GGML_OP_ALIBI
+struct test_alibi : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const int n_past;
+    const int n_head;
+    const float bias_max;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne_a, n_past, n_head, bias_max);
+    }
+
+    test_alibi(ggml_type type = GGML_TYPE_F32,
+               std::array<int64_t, 4> ne_a = {30, 20, 10, 1},
+               int n_past = 0, int n_head = 10,
+               float bias_max = 0.9f)
+        : type(type), ne_a(ne_a), n_past(n_past), n_head(n_head), bias_max(bias_max) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * out = ggml_alibi(ctx, a, n_past, n_head, bias_max);
+        return out;
+    }
+};
+
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -2092,6 +2117,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
+    for (float bias_max : {-0.5, 0.5}) {
+        test_cases.emplace_back(new test_alibi(GGML_TYPE_F32, {16, 2, 10, 1}, 0, 10, bias_max));
+        test_cases.emplace_back(new test_alibi(GGML_TYPE_F32, {16, 2, 32, 1}, 0, 32, bias_max));
+        test_cases.emplace_back(new test_alibi(GGML_TYPE_F32, {128, 4, 10, 1}, 0, 10, bias_max));
+        test_cases.emplace_back(new test_alibi(GGML_TYPE_F32, {128, 4, 32, 1}, 0, 32, bias_max));
+    }
+
     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
     test_cases.emplace_back(new test_llama(1));
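
For reviewers, here is a minimal host-side reference sketch (not part of the patch) of the bias that ggml_cann_alibi composes on the device through the arange, fill_scalar, pow_tensor_tensor, mul, and add steps in the diff above. It assumes F32 data and a batch size of 1 (ne3 == 1); the helper name alibi_bias is illustrative only.

#include <cmath>
#include <cstdint>

// Bias added to src[col, row, head, 0] by the ALIBI op, following the kernel above:
// heads below n_heads_log2_floor get slope m0^(head+1)      (arange1: 1, 2, 3, ...),
// the remaining heads get slope m1^(2*(head - floor) + 1)   (arange2: 1, 3, 5, ...).
static float alibi_bias(int64_t col, int64_t head, int n_head, float max_bias) {
    const int n_heads_log2_floor = 1 << (int) std::floor(std::log2((double) n_head));

    const float m0 = std::pow(2.0f, -max_bias          / n_heads_log2_floor);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

    const float m_k = head < n_heads_log2_floor
        ? std::pow(m0, (float) (head + 1))
        : std::pow(m1, (float) (2 * (head - n_heads_log2_floor) + 1));

    return m_k * (float) col;  // the kernel then adds this to src (KQ_scaled)
}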