@@ -754,13 +754,13 @@ struct server_context {
         default_generation_settings_for_props = get_formated_generation(slots.front());
         default_generation_settings_for_props["seed"] = -1;
 
-        // the update_slots() logic will always submit a maximum of n_batch tokens
+        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(n_batch, 0, 1);
+            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
         }
 
         metrics.init();
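The batch capacity is now `std::max(n_batch, params.n_parallel)` because, during generation, each active slot can contribute one sampled token to the shared batch in a single update_slots() pass, so with many parallel slots and a small `-b` value more than n_batch tokens may be queued. A rough illustration of that generation-time fill (hypothetical loop and approximate field names, not the actual update_slots() code):

```cpp
// Illustration only: every generating slot appends one sampled token per step,
// so the batch must be able to hold params.n_parallel tokens even when n_batch
// is smaller. The field names (sampled, n_past, id) approximate the server
// slot state and are not copied from update_slots().
llama_batch_clear(batch);
for (auto & slot : slots) {
    llama_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
}
// batch.n_tokens can now be as large as slots.size() == params.n_parallel,
// hence the std::max(n_batch, params.n_parallel) capacity in llama_batch_init() above.
```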
@@ -1137,28 +1137,19 @@ struct server_context {
         if (!system_prompt.empty()) {
             system_tokens = ::llama_tokenize(ctx, system_prompt, true);
 
-            llama_batch_clear(batch);
+            const int32_t n_batch = llama_n_batch(ctx);
+            const int32_t n_tokens_prompt = system_tokens.size();
 
-            for (int i = 0; i < (int)system_tokens.size(); ++i) {
-                llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
-            }
+            for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
+                const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
 
-            const int32_t n_batch = llama_n_batch(ctx);
+                llama_batch_clear(batch);
 
-            for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-                const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
-                llama_batch batch_view = {
-                    n_tokens,
-                    batch.token    + i,
-                    nullptr,
-                    batch.pos      + i,
-                    batch.n_seq_id + i,
-                    batch.seq_id   + i,
-                    batch.logits   + i,
-                    0, 0, 0, // unused
-                };
+                for (int32_t j = 0; j < n_tokens; ++j) {
+                    llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
+                }
 
-                if (llama_decode(ctx, batch_view) != 0) {
+                if (llama_decode(ctx, batch) != 0) {
                     LOG_ERROR("llama_decode() failed", {});
                     return;
                 }
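For reference, the new chunked submission reads as a self-contained sketch like the one below (assuming the llama.cpp C API plus the common.h helpers llama_batch_clear / llama_batch_add used in the diff; the batch is assumed to be pre-allocated with at least n_batch capacity, as in the first hunk, and the helper name is made up for illustration):

```cpp
#include <algorithm>
#include <vector>
#include "common.h"
#include "llama.h"

// Sketch: decode an arbitrarily long token list in chunks of at most n_batch
// tokens, reusing one pre-allocated batch, mirroring the system prompt logic above.
static bool decode_tokens_chunked(llama_context * ctx, llama_batch & batch,
                                  const std::vector<llama_token> & tokens) {
    const int32_t n_batch        = llama_n_batch(ctx);
    const int32_t n_tokens_total = (int32_t) tokens.size();

    for (int32_t i = 0; i < n_tokens_total; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, n_tokens_total - i);

        llama_batch_clear(batch);
        for (int32_t j = 0; j < n_tokens; ++j) {
            // token at absolute position i + j, sequence 0, no logits needed for prompt tokens
            llama_batch_add(batch, tokens[i + j], i + j, { 0 }, false);
        }

        if (llama_decode(ctx, batch) != 0) {
            return false; // decoding failed
        }
    }

    return true;
}
```

Because the loop clears and refills the same batch for every chunk, the batch never has to hold the whole system prompt at once, which is what made the previous version overflow a batch allocated with only n_batch capacity.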