Skip to content

llama : reuse compute graphs #14482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.swa_full = true;
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--graph-reuse", "-gr"},
string_format("reuse previous compute graphs when possible (default: %s)"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/14482)", params.graph_reuse ? "true" : "false"),
[](common_params & params) {
params.graph_reuse = true;
}
).set_env("LLAMA_ARG_GRAPH_REUSE"));
add_opt(common_arg(
{"--no-context-shift"},
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
Expand Down
1 change: 1 addition & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;
cparams.graph_reuse = params.graph_reuse;

cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ struct common_params {
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on inifinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
bool graph_reuse = false; // reuse previous compute graphs when possible

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool use_mmap = true; // use mmap for faster loads
Expand Down
3 changes: 3 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ extern "C" {
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573

bool graph_reuse; // reuse previous compute graphs when possible
};

// model quantization parameters
Expand Down Expand Up @@ -1429,6 +1431,7 @@ extern "C" {

int32_t n_p_eval;
int32_t n_eval;
int32_t n_reused;
};

struct llama_perf_sampler_data {
Expand Down
25 changes: 25 additions & 0 deletions src/llama-batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,31 @@ struct llama_ubatch {
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
int8_t * output; // [n_tokens] | i | -

bool is_same(const llama_ubatch & other) const {
bool res =
equal_seqs == other.equal_seqs &&
n_tokens == other.n_tokens &&
n_seq_tokens == other.n_seq_tokens &&
n_seqs == other.n_seqs &&
n_seqs_unq == other.n_seqs_unq &&
(
(!token && !other.token) ||
(!embd && !other.embd)
);

if (!res) {
return false;
}

// TODO: this won't work because seq_id_unq ptr can point to an old balloc that has
// been freed by this point. find a way to fix this
//for (uint32_t s = 0; s < n_seqs_unq; ++s) {
// res &= seq_id_unq[s] == other.seq_id_unq[s];
//}

return res;
}
};

// a helper for sanitizing, fulfilling and splitting a batch
Expand Down
170 changes: 88 additions & 82 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ llama_context::llama_context(

cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

cparams.op_offload = params.op_offload;
cparams.op_offload = params.op_offload;
cparams.graph_reuse = params.graph_reuse;

const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

Expand Down Expand Up @@ -227,8 +228,8 @@ llama_context::llama_context(

LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);

// buffer used to store the computation graph and the tensor meta data
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
gf_res_prev.reset(new llm_graph_result(max_nodes));
gf_res_reserve.reset(new llm_graph_result(max_nodes));

// TODO: move these checks to ggml_backend_sched
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
Expand Down Expand Up @@ -388,10 +389,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
return sched.get();
}

ggml_context * llama_context::get_ctx_compute() const {
return ctx_compute.get();
}

uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}
Expand Down Expand Up @@ -678,38 +675,52 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}

llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}

auto * gf = graph_init();
if (!gf) {
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}
auto * res = gf_res_prev.get();
auto * gf = res->get_gf();

auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
if (!res) {
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}
// the new graph parameters
// in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
const auto gparams = graph_params(res, ubatch, mctx, gtype);

// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
const bool can_reuse = cparams.graph_reuse && res->update(gparams);
if (can_reuse) {
LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
n_reused++;
} else {
res->reset();

if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
ret = GGML_STATUS_ALLOC_FAILED;
return nullptr;
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

//const auto t_start_us = ggml_time_us();

gf = model.build_graph(gparams);

//LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);

if (!gf) {
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}

if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
ret = GGML_STATUS_ALLOC_FAILED;
return nullptr;
}
}

res->set_inputs(&ubatch);

const auto status = graph_compute(gf, ubatch.n_tokens > 1);
const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
if (status != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
ret = status;
Expand Down Expand Up @@ -767,9 +778,6 @@ int llama_context::encode(const llama_batch & batch_inp) {

n_outputs = n_tokens;

ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

const auto causal_attn_org = cparams.causal_attn;

// always use non-causal attention for encoder graphs
Expand All @@ -778,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
cparams.causal_attn = false;

ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);

cparams.causal_attn = causal_attn_org;

Expand Down Expand Up @@ -846,7 +854,9 @@ int llama_context::encode(const llama_batch & batch_inp) {

// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
if (!cparams.graph_reuse) {
ggml_backend_sched_reset(sched.get());
}

// TODO: hacky solution
if (model.arch == LLM_ARCH_T5 && t_embd) {
Expand Down Expand Up @@ -1005,11 +1015,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
n_outputs = n_outputs_new;
}

ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
Expand Down Expand Up @@ -1192,7 +1199,9 @@ int llama_context::decode(const llama_batch & batch_inp) {

// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
if (!cparams.graph_reuse) {
ggml_backend_sched_reset(sched.get());
}

return 0;
}
Expand Down Expand Up @@ -1275,20 +1284,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
// graph
//

int32_t llama_context::graph_max_nodes() const {
return std::max<int32_t>(65536, 5*model.n_tensors());
}

ggml_cgraph * llama_context::graph_init() {
ggml_init_params params = {
/*.mem_size =*/ buf_compute_meta.size(),
/*.mem_buffer =*/ buf_compute_meta.data(),
/*.no_alloc =*/ true,
};

ctx_compute.reset(ggml_init(params));

return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
uint32_t llama_context::graph_max_nodes() const {
return std::max<uint32_t>(65536u, 5u*model.n_tensors());
}

ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
Expand All @@ -1301,6 +1298,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
}

gf_res_prev->reset();
ggml_backend_sched_reset(sched.get());

// store the n_outputs as it is, and restore it afterwards
// TODO: not sure if needed, might simplify in the future by removing this
const auto save_n_outputs = this->n_outputs;
Expand All @@ -1310,17 +1310,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);

auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
auto * res = gf_res_reserve.get();

this->n_outputs = save_n_outputs;
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);

if (!res) {
LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
return nullptr;
}
res->reset();

ggml_backend_sched_reset(sched.get());
auto * gf = model.build_graph(gparams);

this->n_outputs = save_n_outputs;

// initialize scheduler with the specified graph
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
Expand All @@ -1331,28 +1329,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
return gf;
}

llm_graph_result_ptr llama_context::graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_context_i * mctx) {
return model.build_graph(
{
/*.ctx =*/ ctx,
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
/*.cparams =*/ cparams,
/*.ubatch =*/ ubatch,
/*.sched =*/ sched.get(),
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
}, gf, gtype);
llm_graph_params llama_context::graph_params(
llm_graph_result_i * res,
const llama_ubatch & ubatch,
const llama_memory_context_i * mctx,
llm_graph_type gtype) const {
return {
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
/*.cparams =*/ cparams,
/*.ubatch =*/ ubatch,
/*.gtype =*/ gtype,
/*.sched =*/ sched.get(),
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
/*.res =*/ res,
};
}

ggml_status llama_context::graph_compute(
Expand Down Expand Up @@ -1930,6 +1927,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
data.t_eval_ms = 1e-3 * t_eval_us;
data.n_p_eval = std::max(1, n_p_eval);
data.n_eval = std::max(1, n_eval);
data.n_reused = std::max(0, n_reused);

return data;
}
Expand All @@ -1938,6 +1936,7 @@ void llama_context::perf_reset() {
t_start_us = ggml_time_us();
t_eval_us = n_eval = 0;
t_p_eval_us = n_p_eval = 0;
n_reused = 0;
}

//
Expand Down Expand Up @@ -2064,8 +2063,13 @@ void llama_context::opt_epoch_iter(
break;
}

auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
auto * res = gf_res_prev.get();

const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);

res->reset();

auto * gf = model.build_graph(gparams);

struct ggml_context * ctx_compute_opt;
{
Expand Down Expand Up @@ -2187,6 +2191,7 @@ llama_context_params llama_context_default_params() {
/*.no_perf =*/ true,
/*.op_offload =*/ true,
/*.swa_full =*/ true,
/*.graph_reuse =*/ false,
};

return result;
Expand Down Expand Up @@ -2807,6 +2812,7 @@ void llama_perf_context_print(const llama_context * ctx) {
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
}

void llama_perf_context_reset(llama_context * ctx) {
Expand Down
Loading
Loading