```diff
@@ -2407,7 +2407,7 @@ struct server_context {
 
             params_dft.devices      = params_base.speculative.devices;
             params_dft.model        = params_base.speculative.model;
-            params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+            params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? slots.front().n_ctx : params_base.speculative.n_ctx;
            params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
            params_dft.n_parallel   = 1;
            params_dft.cache_type_k = params_base.speculative.cache_type_k;
```
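The draft model's default context now mirrors the per-slot context of the target model instead of recomputing `n_ctx / n_parallel`, which would undersize the draft when a unified KV cache lets every slot address the full context. A minimal sketch of the selection logic, with a made-up helper name (an explicitly configured draft context corresponds to `speculative.n_ctx != 0`):

```cpp
#include <cstdint>

// Hypothetical helper mirroring the new default: an explicit draft context
// (speculative.n_ctx != 0) wins; otherwise the draft follows the target's
// per-slot context.
int32_t resolve_draft_n_ctx(int32_t spec_n_ctx, int32_t slot_n_ctx) {
    return spec_n_ctx == 0 ? slot_n_ctx : spec_n_ctx;
}
```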
```diff
@@ -2495,7 +2495,7 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
+        const int32_t n_ctx_slot = params_base.kv_unified ? n_ctx : n_ctx / params_base.n_parallel;
 
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
```
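With a unified KV cache (`params_base.kv_unified`) all slots share one buffer and each slot may use the full `n_ctx`; with split caches the context is still divided evenly across slots. A runnable sketch of the two sizing modes (numbers are illustrative, not server defaults):

```cpp
#include <cstdint>
#include <cstdio>

// Per-slot context under the two KV-cache layouts (sketch, not server code).
int32_t slot_n_ctx(int32_t n_ctx, int32_t n_parallel, bool kv_unified) {
    return kv_unified ? n_ctx : n_ctx / n_parallel;
}

int main() {
    const int32_t n_ctx = 8192, n_parallel = 4;
    std::printf("split:   %d tokens per slot\n", slot_n_ctx(n_ctx, n_parallel, false)); // 2048
    std::printf("unified: %d tokens per slot\n", slot_n_ctx(n_ctx, n_parallel, true));  // 8192
    return 0;
}
```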
```diff
@@ -2699,6 +2699,36 @@ struct server_context {
         return ret;
     }
 
+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    //       - smarter decision which slot to purge
+    //       - move slot to level 2 cache instead of removing?
+    //       - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+            }
+        }
+
+        return res;
+    }
+
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
 
```
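`try_purge_idle_slots()` only applies in unified mode: there the slots compete for one pool of KV cells, so evicting the cached prompt of an idle slot can unblock an active one. Below is a simplified, self-contained model of the pass using toy structs instead of llama.cpp types; the real method additionally calls `llama_memory_seq_rm()` to drop the slot's sequence from the KV cache:

```cpp
#include <cstdio>
#include <vector>

// Toy stand-in for a server slot: id, busy flag, and cached token count.
struct toy_slot {
    int  id;
    bool processing;
    int  n_cached_tokens;
};

// Purge cached tokens of idle slots; report whether anything was freed.
bool try_purge_idle(std::vector<toy_slot> & slots, bool kv_unified) {
    bool res = false;

    if (!kv_unified) {
        return res; // split caches: each slot owns its share, nothing to reclaim
    }

    for (auto & s : slots) {
        if (s.processing || s.n_cached_tokens == 0) {
            continue;
        }

        std::printf("purging slot %d with %d tokens\n", s.id, s.n_cached_tokens);
        s.n_cached_tokens = 0; // the server also removes the sequence from the KV cache here

        res = true;
    }

    return res;
}

int main() {
    std::vector<toy_slot> slots = {{0, true, 512}, {1, false, 300}, {2, false, 0}};
    std::printf("purged: %s\n", try_purge_idle(slots, true) ? "yes" : "no"); // purges slot 1
    return 0;
}
```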
```diff
@@ -3635,9 +3665,10 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // next, batch any pending prompts without exceeding n_batch
-        float alora_scale = -1.0f;
+        float alora_scale = -1.0f;
         size_t alora_disabled_id = 0;
+
+        // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
```
```diff
@@ -4126,6 +4157,8 @@ struct server_context {
                     std::string err;
 
                     if (n_batch == 1 && ret == 1) {
+                        // TODO: try to terminate only the largest active slot and continue
+                        //       need to remove the tokens from the current batch too
                         err = "Context size has been exceeded.";
                     }
 
```
```diff
@@ -4141,17 +4174,23 @@ struct server_context {
                     // TODO: handle ret == 2 (abort) when we start aborting
 
                     if (!err.empty()) {
-                        SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                        SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+
                         for (auto & slot : slots) {
-                            send_error(slot, err);
-                            slot.release();
+                            if (slot.is_processing()) {
+                                send_error(slot, err);
+                                slot.release();
+                            }
                         }
+
                         break;
                     }
                 }
 
                 // retry with half the batch size to try to find a free slot in the KV cache
-                n_batch /= 2;
+                if (!try_purge_idle_slots()) {
+                    n_batch /= 2;
+                }
 
                 SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
```
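The recovery path on a failed decode changes accordingly: instead of always halving the batch, the server first tries to reclaim cells from idle slots and only shrinks the batch when nothing could be purged. A runnable toy of that ordering (stub state and names, not the server's API):

```cpp
#include <cstdio>

static int  free_cells       = 0;    // toy KV-cache state
static bool idle_reclaimable = true; // one idle slot holding cached tokens

// Stand-in for the server method: reclaim cells from an idle slot, if any.
bool try_purge_idle_slots() {
    if (!idle_reclaimable) {
        return false;
    }
    idle_reclaimable = false;
    free_cells += 300; // cells reclaimed from the idle slot
    return true;
}

int main() {
    int n_batch = 256;

    // "decode failed" stands in for llama_decode() reporting no free KV cells
    while (free_cells < n_batch && n_batch > 0) {
        if (!try_purge_idle_slots()) {
            n_batch /= 2; // old behavior, now only the fallback
        }
        std::printf("retry: n_batch = %d, free_cells = %d\n", n_batch, free_cells);
    }
    return 0;
}
```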
```diff
@@ -4944,7 +4983,7 @@ int main(int argc, char ** argv) {
                 // Everything else, including multimodal completions.
                 inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
             }
-            const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
             tasks.reserve(inputs.size());
             for (size_t i = 0; i < inputs.size(); i++) {
                 auto n_prompt_tokens = inputs[i].size();
```
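The endpoint-side validation now reads the per-slot limit from `slots.front().n_ctx` instead of re-deriving it, so it stays consistent with the unified-mode sizing above. A toy version of the length check (the exact comparison and error text are illustrative, not copied from the server):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    const size_t n_ctx_slot      = 8192; // would come from ctx_server.slots.front().n_ctx
    const size_t n_prompt_tokens = 9000; // tokenized prompt length

    // illustrative guard: reject prompts that cannot fit in a slot's context
    if (n_prompt_tokens >= n_ctx_slot) {
        std::printf("error: prompt too long (%zu >= %zu)\n", n_prompt_tokens, n_ctx_slot);
        return 1;
    }
    return 0;
}
```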