
llama : add high-throughput mode #14363

Draft: wants to merge 29 commits into base: gg/kv-cache-use-set-rows

Commits (29, showing changes from all commits):
c1a581a
ggml : add ggml_set_rows
rgerganov Jun 19, 2025
f2cd962
use I64 for indices
rgerganov Jun 20, 2025
695b6b7
ggml : add repeat impl for i64
ggerganov Jun 21, 2025
313a444
ggml : add ggml_is_contiguous_rows
ggerganov Jun 22, 2025
df71c80
ggml : ggml_set_rows support broadcast
ggerganov Jun 22, 2025
630c84a
ggml : ggml_set_rows support quantized dst
ggerganov Jun 22, 2025
e897097
ggml : support GGML_TYPE_F32 ".from_float" trait
ggerganov Jun 22, 2025
e73690a
ggml : ggml_set_rows update comment + better index name
ggerganov Jun 22, 2025
828e5d2
tests : add ggml_set_rows
ggerganov Jun 22, 2025
c0cfc2f
metal : add ggml_set_rows implementation
ggerganov Jun 22, 2025
eba9757
ggml : simplify forward_dup_f32
rgerganov Jun 23, 2025
1f647b5
ggml : fix supports_op
rgerganov Jun 23, 2025
79dac3c
kv-cache : use ggml_set_rows
ggerganov Jun 19, 2025
db2bb37
cont : gate the ggml_set_rows usage with env var
ggerganov Jun 21, 2025
f875d6c
cont : migrate to using set of indices instead of slot head
ggerganov Jun 21, 2025
39d0b1e
cont : kv-cells cp/set for non-cont slots
ggerganov Jun 21, 2025
332f073
cont : support non-continuous slots
ggerganov Jun 21, 2025
36f8e20
kv-cache : utilize ggml_set_rows broadcast
ggerganov Jun 22, 2025
52b9007
llama : add "virtual sequences"
ggerganov Jun 23, 2025
1321439
tools : tmp adjustments (TMP)
ggerganov Jun 24, 2025
401c13e
cont : fix build
ggerganov Jun 24, 2025
7c6487b
metal : extend ggml_soft_max_ext() to support n_seq dim
ggerganov Jun 24, 2025
8c68219
kv-cache : fix non-FA path with virtual sequences
ggerganov Jun 24, 2025
1b74b9d
ggml : extend support for n_seq for soft_max and fattn
ggerganov Jun 24, 2025
165d822
graph : support iSWA virtual sequences
ggerganov Jun 24, 2025
0bb1da5
kv-cache : simplify set_rows logic
ggerganov Jun 24, 2025
6663128
kv-cache : rework kv_idxs, support seq_cp
ggerganov Jun 25, 2025
5eb1a88
batch : optional requirement for sequential sequence ids
ggerganov Jun 25, 2025
6179578
batch : require non-coupled batch with sequential split_equal
ggerganov Jun 25, 2025
2 changes: 2 additions & 0 deletions examples/eval-callback/eval-callback.cpp
@@ -55,6 +55,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
                 v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
             } else if (type == GGML_TYPE_F32) {
                 v = *(float *) &data[i];
+            } else if (type == GGML_TYPE_I64) {
+                v = (float) *(int64_t *) &data[i];
             } else if (type == GGML_TYPE_I32) {
                 v = (float) *(int32_t *) &data[i];
             } else if (type == GGML_TYPE_I16) {
3 changes: 2 additions & 1 deletion examples/parallel/parallel.cpp
@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
 
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch batch = llama_batch_init(n_ctx, 0, 1);
+    llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
@@ -289,6 +289,7 @@ int main(int argc, char ** argv) {
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
                 llama_memory_seq_rm(mem, i, -1, -1);
+
                 // but keep the system prompt
                 llama_memory_seq_cp(mem, 0, i, -1, -1);
             }
1 change: 1 addition & 0 deletions ggml/include/ggml-cpu.h
@@ -133,6 +133,7 @@ extern "C" {
 
     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
 
+    GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
21 changes: 21 additions & 0 deletions ggml/include/ggml.h
@@ -470,6 +470,7 @@ extern "C" {
         GGML_OP_TRANSPOSE,
         GGML_OP_GET_ROWS,
         GGML_OP_GET_ROWS_BACK,
+        GGML_OP_SET_ROWS,
         GGML_OP_DIAG,
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
@@ -687,6 +688,9 @@ extern "C" {
     // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
     GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
 
+    // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
+    GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
 
@@ -1375,6 +1379,23 @@ extern "C" {
             struct ggml_tensor  * b,  // row indices
             struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape
 
+    // a TD  [n_embd, ne1,    ne2,  ne3]
+    // b TS  [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
+    // c I64 [n_rows, ne11,   ne12, 1]    | c[i] in [0, ne1)
+    //
+    // undefined behavior if destination rows overlap
+    //
+    // broadcast:
+    //   ne2 % ne11 == 0
+    //   ne3 % ne12 == 0
+    //
+    // return view(a)
+    GGML_API struct ggml_tensor * ggml_set_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // destination
+            struct ggml_tensor  * b,  // source
+            struct ggml_tensor  * c); // row indices
+
     GGML_API struct ggml_tensor * ggml_diag(
         struct ggml_context     * ctx,
         struct ggml_tensor      * a);
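
For orientation, a minimal usage sketch of the new operator (an editor's illustration, not part of the diff; `ctx`, `gf`, and the tensor sizes are hypothetical, and the shape/type rules follow the spec comment above):

```c
// dst: tensor to scatter into (quantized destinations are supported per the
//      "ggml_set_rows support quantized dst" commit)
// src: F32 rows to write
// idx: I64 row indices, each value in [0, dst->ne[1])
struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, n_embd, n_kv);
struct ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  n_embd, n_tokens);
struct ggml_tensor * idx = ggml_new_tensor_1d(ctx, GGML_TYPE_I64,  n_tokens);

// row i of src is converted to dst->type and stored at row idx[i] of dst;
// the result is a view of dst, so it is expanded like any other graph op
struct ggml_tensor * out = ggml_set_rows(ctx, dst, src, idx);
ggml_build_forward_expand(gf, out);
```

This appears to be the primitive that lets the KV cache write K/V data to non-contiguous cells inside the graph (see the "kv-cache : use ggml_set_rows" and "cont : support non-continuous slots" commits) instead of requiring a contiguous slot head.
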
10 changes: 10 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -192,6 +192,7 @@ typedef pthread_t ggml_thread_t;
 
 static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = {
+        .from_float               = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
         .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
         .vec_dot_type             = GGML_TYPE_F32,
         .nrows                    = 1,
@@ -1814,6 +1815,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_get_rows_back(params, tensor);
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                ggml_compute_forward_set_rows(params, tensor);
+            } break;
         case GGML_OP_DIAG:
             {
                 ggml_compute_forward_diag(params, tensor);
@@ -2167,6 +2172,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_GET_ROWS:
+        case GGML_OP_SET_ROWS:
             {
                 // FIXME: get_rows can use additional threads, but the cost of launching additional threads
                 // decreases performance with GPU offloading
@@ -3121,6 +3127,10 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
     return ggml_graph_compute(cgraph, &cplan);
 }
 
+void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
+    memcpy(y, x, n * sizeof(float));
+}
+
 void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
     int64_t i = 0;
 #if defined(__F16C__)
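
With the new F32 entry, every destination type, including plain F32, exposes a `from_float` conversion, so generic code such as `ggml_set_rows` and the simplified `forward_dup_f32` can use a single path. A sketch of the calling pattern (the row buffers and `n` here are illustrative, not from the diff):

```c
// Convert one row of n floats into dst_type's storage format. For
// GGML_TYPE_F32 the trait is ggml_cpu_fp32_to_fp32 (a plain memcpy);
// for quantized types it is the matching quantize_row_* routine.
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst_type)->from_float;
if (from_float != NULL) {
    from_float(src_row, dst_row, n);
}
```
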
1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -416,6 +416,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
 
     switch (op->op) {
         case GGML_OP_CPY:
+        case GGML_OP_SET_ROWS:
             return
                 op->type != GGML_TYPE_IQ3_XXS &&
                 op->type != GGML_TYPE_IQ3_S   &&
153 changes: 130 additions & 23 deletions ggml/src/ggml-cpu/ops.cpp
@@ -696,24 +696,8 @@ static void ggml_compute_forward_dup_f32(
     if (ggml_is_contiguous(dst)) {
         // TODO: simplify
         if (nb00 == sizeof(float)) {
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+            if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
 
                 size_t id = 0;
                 size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -724,7 +708,7 @@
                         id += rs * ir0;
                         for (int i01 = ir0; i01 < ir1; i01++) {
                             const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            from_float(src0_ptr, dst_ptr + id, ne00);
                             id += rs;
                         }
                         id += rs * (ne01 - ir1);
@@ -2282,6 +2266,52 @@ static void ggml_compute_forward_repeat_f16(
     }
 }
 
+static void ggml_compute_forward_repeat_i64(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(int64_t));
+    GGML_ASSERT(nb00 == sizeof(int64_t));
+
+    // TODO: maybe this is not optimal?
+    for (int i3 = 0; i3 < nr3; i3++) {
+        for (int k3 = 0; k3 < ne03; k3++) {
+            for (int i2 = 0; i2 < nr2; i2++) {
+                for (int k2 = 0; k2 < ne02; k2++) {
+                    for (int i1 = 0; i1 < nr1; i1++) {
+                        for (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                int64_t * y = (int64_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
+                                int64_t * x = (int64_t *) ((char *) src0->data + (k3)*nb03 + (k2)*nb02 + (k1)*nb01);
+                                for (int i = 0; i < ne00; ++i) {
+                                    y[i] = x[i];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
 void ggml_compute_forward_repeat(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -2300,6 +2330,10 @@
             {
                 ggml_compute_forward_repeat_f32(params, dst);
             } break;
+        case GGML_TYPE_I64:
+            {
+                ggml_compute_forward_repeat_i64(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -4470,6 +4504,74 @@ void ggml_compute_forward_get_rows(
     //}
 }
 
+static void ggml_compute_forward_set_rows_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ne01;
+
+    assert(ne0  == nc);
+    assert(ne2  == ne02);
+    assert(ne3  == ne03);
+    assert(src0->type == GGML_TYPE_F32);
+    assert(ne02 % ne11 == 0);
+    assert(ne03 % ne12 == 0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
+
+    for (int64_t i03 = 0; i03 < ne03; ++i03) {
+        for (int64_t i02 = 0; i02 < ne02; ++i02) {
+            for (int64_t i = ir0; i < ir1; ++i) {
+                const int64_t i12 = i03%ne12;
+                const int64_t i11 = i02%ne11;
+                const int64_t i10 = i;
+
+                const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                GGML_ASSERT(i1 >= 0 && i1 < ne1);
+
+                from_float(
+                        (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
+                        ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_set_rows(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_get_rows_back
 
 static void ggml_compute_forward_get_rows_back_f32_f16(
@@ -4751,7 +4853,8 @@ static void ggml_compute_forward_soft_max_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -4776,6 +4879,10 @@
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
+        const int64_t i11 = (i1%ne01);
+        //const int64_t i12 = (i1/ne01)%ne02;
+        const int64_t i13 = (i1/ne01)/ne02;
+
         // ALiBi
         const uint32_t h = (i1/ne01)%ne02; // head
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
@@ -4784,8 +4891,8 @@
         float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
 
         // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i13*nb12) : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i13*nb12) : NULL;
 
         ggml_vec_cpy_f32  (nc, wp, sp);
         ggml_vec_scale_f32(nc, wp, scale);
@@ -7125,7 +7232,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + iq3*mask->nb[2]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
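
The soft_max and flash-attention changes above share one idea: the mask is now addressed per (row, sequence) rather than per row, using the new `nb12`/`nb[2]` stride, so each of the virtual sequences gets its own mask slice. A sketch of the index arithmetic as implied by the diff (names match the soft_max code above; this is an editor's gloss, not new API):

```c
// dst rows are laid out as ne01 rows x ne02 heads x n_seq sequences.
// Decompose the flat row index i1 and locate the mask row for this sequence.
const int64_t i11 = i1 % ne01;          // row within the attention matrix
const int64_t i13 = (i1 / ne01) / ne02; // virtual-sequence index
const float * mp_f32 =
    (const float *)((const char *) src1->data + i11*nb11 + i13*nb12);
```
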
1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/ops.h
@@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str
 void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);
19 changes: 19 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -230,6 +230,7 @@ typedef struct {
     uint64_t nb22;
     uint64_t nb23;
     uint64_t nb31;
+    uint64_t nb32;
     int32_t  ne1;
     int32_t  ne2;
     float    scale;
@@ -453,6 +454,8 @@ typedef struct {
     int64_t  ne00;
     int64_t  ne01;
     int64_t  ne02;
+    uint64_t nb11;
+    uint64_t nb12;
     float    scale;
     float    max_bias;
     float    m0;
@@ -521,6 +524,22 @@ typedef struct {
     uint64_t nb2;
 } ggml_metal_kargs_get_rows;
 
+typedef struct {
+    int32_t  nk0;
+    int32_t  ne01;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_set_rows;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;