
Commit ec68e84

ggml : support bcast ggml_soft_max_ext, ggml_flash_attn_ext (#14435)
ggml-ci
1 parent 307e79d commit ec68e84

11 files changed: +247 −153 lines changed


ggml/include/ggml.h

Lines changed: 17 additions & 6 deletions
@@ -1510,8 +1510,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    // a [ne0, ne01, ne02, ne03]
+    // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
+    //
+    // broadcast:
+    //   ne02 % ne12 == 0
+    //   ne03 % ne13 == 0
+    //
     // fused soft_max(a*scale + mask*(ALiBi slope))
-    // mask is optional
     // max_bias = 0.0f for no ALiBi
     GGML_API struct ggml_tensor * ggml_soft_max_ext(
             struct ggml_context * ctx,
@@ -1974,11 +1980,16 @@ extern "C" {
 
 #define GGML_KQ_MASK_PAD 64
 
-    // q:    [n_embd_k, n_batch,     n_head,    1]
-    // k:    [n_embd_k, n_kv,        n_head_kv, 1]
-    // v:    [n_embd_v, n_kv,        n_head_kv, 1] !! not transposed !!
-    // mask: [n_kv,     n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd_v, n_head,      n_batch,   1] !! permuted !!
+    // q:    [n_embd_k, n_batch,     n_head,    ne3]
+    // k:    [n_embd_k, n_kv,        n_head_kv, ne3]
+    // v:    [n_embd_v, n_kv,        n_head_kv, ne3] !! not transposed !!
+    // mask: [n_kv,     n_batch_pad, ne32,      1  ] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head,      n_batch,   ne3] !! permuted !!
+    //
+    // broadcast:
+    //   n_head % n_head_kv == 0
+    //   ne3    % ne32      == 0
+    //
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor * q,
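
The broadcast rules documented above can be exercised directly from the public API. Below is a minimal shape-level sketch, not part of this commit: the tensor sizes are made-up illustration values chosen to satisfy the divisibility rules, the context is created with `no_alloc` so only graph nodes are built, and the `ggml_flash_attn_ext` argument list (scale, max_bias, logit_softcap) is assumed from the current header.

```c
// Illustrative only: hypothetical sizes satisfying ne02 % ne12 == 0, ne03 % ne13 == 0,
// n_head % n_head_kv == 0 and ne3 % ne32 == 0. Builds the nodes, does not run them.
#include "ggml.h"
#include <math.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,  // metadata only, no tensor data
    };
    struct ggml_context * ctx = ggml_init(ip);

    // soft_max_ext: a is [ne0, ne01, ne02, ne03], mask is [ne0, ne11, ne12, ne13];
    // the mask is broadcast across dims 2 and 3 (8 % 2 == 0, 4 % 1 == 0).
    struct ggml_tensor * a    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 32, 8, 4);
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 32, 2, 1);
    struct ggml_tensor * sm   = ggml_soft_max_ext(ctx, a, mask, 0.125f, 0.0f);

    // flash_attn_ext: grouped KV heads (n_head % n_head_kv == 0) and a mask whose
    // dim 2 (ne32) divides ne3 of q/k/v.
    const int64_t n_embd_k = 128, n_embd_v = 128, n_kv = 256, n_batch = 32;
    const int64_t n_head = 8, n_head_kv = 2, ne3 = 4, ne32 = 1;
    const int64_t n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD);

    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_k, n_batch,     n_head,    ne3);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_k, n_kv,        n_head_kv, ne3);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_v, n_kv,        n_head_kv, ne3);
    struct ggml_tensor * m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_kv,     n_batch_pad, ne32,      1);

    struct ggml_tensor * fa = ggml_flash_attn_ext(ctx, q, k, v, m,
            1.0f/sqrtf((float) n_embd_k), /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);

    (void) sm; (void) fa;
    ggml_free(ctx);
    return 0;
}
```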

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 6 additions & 1 deletion
@@ -2187,7 +2187,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
         case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
@@ -2205,6 +2204,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_SOFT_MAX:
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_FLASH_ATTN_EXT:{
             // derived from [ggml-cuda.cu]
             if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
@@ -2227,6 +2230,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // DeepSeek MLA
                 return false;
             }
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }

ggml/src/ggml-cpu/ops.cpp

Lines changed: 55 additions & 54 deletions
@@ -5232,14 +5232,17 @@ static void ggml_compute_forward_soft_max_f32(
     memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,68 +5252,66 @@ static void ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        // ALiBi
-        const uint32_t h = (i1/ne01)%ne02; // head
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
-
-        // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
-        ggml_vec_cpy_f32 (nc, wp, sp);
-        ggml_vec_scale_f32(nc, wp, scale);
-        if (mp_f32) {
-            if (use_f16) {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
-                }
-            } else {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*mp_f32[i];
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const int64_t i11 = i01;
+                const int64_t i12 = i02%ne12;
+                const int64_t i13 = i03%ne13;
+
+                // ALiBi
+                const uint32_t h = i02; // head
+                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                // broadcast the mask across rows
+                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+                float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+                ggml_vec_cpy_f32 (ne00, wp, sp);
+                ggml_vec_scale_f32(ne00, wp, scale);
+                if (mp_f32) {
+                    if (use_f16) {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+                        }
+                    } else {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*mp_f32[i];
+                        }
+                    }
                 }
-            }
-        }
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    //printf("p[%d] = %f\n", i, p[i]);
+                    assert(!isnan(wp[i]));
+                }
 #endif
 
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, wp);
+                float max = -INFINITY;
+                ggml_vec_max_f32(ne00, &max, wp);
 
-        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
+                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+                assert(sum > 0.0);
 
-        sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, dp, sum);
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(ne00, dp, sum);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    assert(!isnan(dp[i]));
+                    assert(!isinf(dp[i]));
+                }
 #endif
+            }
+        }
     }
 }
 
@@ -7766,7 +7767,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+    ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
@@ -7798,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
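
The core of the CPU change is the per-row index mapping: the mask row is picked by wrapping the higher dims with a modulo, the same pattern ggml uses for binary-op broadcasting. For reference, that addressing can be written as a tiny standalone helper; this is a restatement for clarity, not code from the patch, and the function name and parameters are made up.

```c
#include <stddef.h>
#include <stdint.h>

// Hypothetical helper: given the row (i01, i02, i03) of src0 being softmax-ed,
// return the matching mask row under the broadcast rules ne02 % ne12 == 0 and
// ne03 % ne13 == 0. Mirrors the loop added to ggml_compute_forward_soft_max_f32.
static const void * soft_max_mask_row(const void * mask_data,
                                      int64_t i01, int64_t i02, int64_t i03,
                                      int64_t ne12, int64_t ne13,
                                      size_t  nb11, size_t  nb12, size_t nb13) {
    const int64_t i11 = i01;        // rows are never broadcast (ne11 >= ne01)
    const int64_t i12 = i02 % ne12; // broadcast across dim 2 (heads)
    const int64_t i13 = i03 % ne13; // broadcast across dim 3 (batch/sequences)

    return (const char *) mask_data + i11*nb11 + i12*nb12 + i13*nb13;
}
```

The same wrapping shows up in the flash-attention change above, where the mask plane is selected with `(iq3%mask->ne[2])*mask->nb[2]`.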

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 10 additions & 1 deletion
@@ -3327,8 +3327,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
         case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
             return true;
+        case GGML_OP_SOFT_MAX:
+            // TODO: support batching
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_SOFT_MAX_BACK: {
             float max_bias = 0.0f;
             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
@@ -3375,6 +3382,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
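
Until the CANN and CUDA kernels learn the new broadcast path, both `supports_op` hooks above gate on the same predicate: accept the op only when the mask, if present, is not broadcast along dims 2 and 3. A restatement of that check as an illustrative helper (hypothetical name, not code from the patch):

```c
#include <stdbool.h>
#include <stddef.h>
#include "ggml.h"

// Illustrative helper: true when soft_max can take the existing non-broadcast
// path — either there is no mask, or the mask has a single plane in dims 2 and 3.
static bool soft_max_mask_not_broadcast(const struct ggml_tensor * op) {
    const struct ggml_tensor * mask = op->src[1];
    return mask == NULL || (mask->ne[2] == 1 && mask->ne[3] == 1);
}
```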

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 17 additions & 3 deletions
@@ -229,7 +229,9 @@ typedef struct {
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
+    int32_t  ne32;
     uint64_t nb31;
+    uint64_t nb32;
     int32_t  ne1;
     int32_t  ne2;
     float    scale;
@@ -461,9 +463,21 @@ typedef struct {
 } ggml_metal_kargs_sum_rows;
 
 typedef struct {
-    int64_t  ne00;
-    int64_t  ne01;
-    int64_t  ne02;
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    scale;
     float    max_bias;
     float    m0;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 17 additions & 6 deletions
@@ -1725,7 +1725,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
-            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
+            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
@@ -2644,10 +2644,7 @@ static bool ggml_metal_encode_node(
                 memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
                 memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
 
-                const int64_t nrows_x = ggml_nrows(src0);
-                const int64_t nrows_y = src0->ne[1];
-
-                const uint32_t n_head = nrows_x/nrows_y;
+                const uint32_t n_head = src0->ne[2];
                 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
                 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -2707,6 +2704,18 @@ static bool ggml_metal_encode_node(
                     /*.ne00     =*/ ne00,
                     /*.ne01     =*/ ne01,
                     /*.ne02     =*/ ne02,
+                    /*.nb01     =*/ nb01,
+                    /*.nb02     =*/ nb02,
+                    /*.nb03     =*/ nb03,
+                    /*.ne11     =*/ ne11,
+                    /*.ne12     =*/ ne12,
+                    /*.ne13     =*/ ne13,
+                    /*.nb11     =*/ nb11,
+                    /*.nb12     =*/ nb12,
+                    /*.nb13     =*/ nb13,
+                    /*.nb1      =*/ nb1,
+                    /*.nb2      =*/ nb2,
+                    /*.nb3      =*/ nb3,
                     /*.scale    =*/ scale,
                     /*.max_bias =*/ max_bias,
                     /*.m0       =*/ m0,
@@ -2726,7 +2735,7 @@ static bool ggml_metal_encode_node(
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_DIAG_MASK_INF:
             {
@@ -4979,7 +4988,9 @@ static bool ggml_metal_encode_node(
                 /*.nb21    =*/ nb21,
                 /*.nb22    =*/ nb22,
                 /*.nb23    =*/ nb23,
+                /*.ne32    =*/ ne32,
                 /*.nb31    =*/ nb31,
+                /*.nb32    =*/ nb32,
                 /*.ne1     =*/ ne1,
                 /*.ne2     =*/ ne2,
                 /*.scale   =*/ scale,
