@@ -1023,9 +1023,6 @@ struct llama_kv_cache {
     uint32_t head = 0;
     uint32_t size = 0;
 
-    // largest index of an occupied cell (used for a basic optimization heuristic)
-    uint32_t cell_max = 0;
-
     std::vector<llama_kv_cell> cells;
 
     struct ggml_tensor * k = NULL;
@@ -1229,8 +1226,6 @@ static bool llama_kv_cache_init(
     cache.head = 0;
     cache.size = n_ctx;
 
-    cache.cell_max = 0;
-
     cache.cells.clear();
     cache.cells.resize(n_ctx);
@@ -1316,15 +1311,16 @@ static bool llama_kv_cache_find_slot(
     return true;
 }
 
-void llama_kv_cache_update(struct llama_kv_cache & cache) {
-    // compute new cell_max
-    cache.cell_max = 0;
+int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
+    int32_t res = 0;
 
     for (uint32_t i = 0; i < cache.size; i++) {
-        if (cache.cells[i].pos >= 0) {
-            cache.cell_max = i + 1;
+        if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
+            res = i + 1;
         }
     }
+
+    return res;
 }
 
 void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
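
Note on the helper above: a cell only counts toward the bound if it both holds a valid position and is still referenced by at least one sequence. A minimal standalone sketch of the same rule, using a simplified cell type rather than the actual llama_kv_cell:

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

struct cell {
    int32_t pos = -1;          // -1 means the cell is free
    std::set<int32_t> seq_id;  // sequences that still reference this cell
};

// Highest occupied index + 1. A cell that has a position but is no longer
// referenced by any sequence does not count as occupied.
static int32_t cell_max(const std::vector<cell> & cells) {
    int32_t res = 0;
    for (size_t i = 0; i < cells.size(); ++i) {
        if (cells[i].pos >= 0 && !cells[i].seq_id.empty()) {
            res = (int32_t) i + 1;
        }
    }
    return res;
}

int main() {
    std::vector<cell> cells(8);
    cells[0].pos = 0; cells[0].seq_id = {0};
    cells[1].pos = 1; cells[1].seq_id = {0};
    cells[5].pos = 5;                            // stale: position set, no owner sequence
    printf("cell_max = %d\n", cell_max(cells));  // prints 2, not 6
    return 0;
}
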
@@ -1335,8 +1331,6 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
-
-    llama_kv_cache_update(cache);
 }
 
 void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
@@ -1348,8 +1342,6 @@ void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
             }
         }
     }
-
-    llama_kv_cache_update(cache);
 }
 
 void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
@@ -1359,8 +1351,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
             cache.cells[i].seq_id.clear();
         }
     }
+}
+
+void llama_kv_cache_shift(
+        struct llama_context & ctx,
+                 llama_seq_id   seq_id,
+                    llama_pos   p0,
+                    llama_pos   p1,
+                    llama_pos   delta) {
+    auto & hparams = ctx.model.hparams;
+    auto & cache   = ctx.kv_self;
 
-    llama_kv_cache_update(cache);
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            cache.cells[i].pos += delta;
+        }
+    }
 }
 
 //
@@ -2587,7 +2593,7 @@ static struct ggml_cgraph * llm_build_llama(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
 
     auto & buf_compute = lctx.buf_compute;
@@ -2678,13 +2684,6 @@ static struct ggml_cgraph * llm_build_llama(
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                     }
                 }
-
-                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
-                for (int i = n_kv; i < n_ctx; ++i) {
-                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
-                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
-                    }
-                }
             }
         }
     }
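
The removed loop was only a temporary sanity check; the mask itself is still built over the first n_kv cells. A rough standalone sketch of the masking rule, with simplified types; the exact visibility condition here is an assumption based on the surrounding code, not part of this diff:

#include <cstdint>
#include <limits>
#include <set>
#include <vector>

struct cell {
    int32_t pos = -1;
    std::set<int32_t> seq_id;
    bool has_seq_id(int32_t id) const { return seq_id.count(id) > 0; }
};

// mask[j*n_kv + i] == -INFINITY means KV cell i is hidden from token j:
// the cell either belongs to a different sequence or lies in the future
// relative to the token's position (causal masking).
static std::vector<float> build_kq_mask(
        const std::vector<cell> & cells, int32_t n_kv,
        const std::vector<int32_t> & tok_pos,
        const std::vector<int32_t> & tok_seq) {
    const int32_t n_tokens = (int32_t) tok_pos.size();
    std::vector<float> mask((size_t) n_kv * n_tokens, 0.0f);
    for (int32_t j = 0; j < n_tokens; ++j) {
        for (int32_t i = 0; i < n_kv; ++i) {
            if (!cells[i].has_seq_id(tok_seq[j]) || cells[i].pos > tok_pos[j]) {
                mask[(size_t) j*n_kv + i] = -std::numeric_limits<float>::infinity();
            }
        }
    }
    return mask;
}
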
@@ -2952,7 +2951,7 @@ static struct ggml_cgraph * llm_build_baichaun(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
 
     auto & buf_compute = lctx.buf_compute;
@@ -3043,13 +3042,6 @@ static struct ggml_cgraph * llm_build_baichaun(
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                     }
                 }
-
-                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
-                for (int i = n_kv; i < n_ctx; ++i) {
-                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
-                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
-                    }
-                }
             }
         }
     }
@@ -3334,7 +3326,7 @@ static struct ggml_cgraph * llm_build_falcon(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
 
     auto & buf_compute = lctx.buf_compute;
@@ -3425,13 +3417,6 @@ static struct ggml_cgraph * llm_build_falcon(
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                     }
                 }
-
-                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
-                for (int i = n_kv; i < n_ctx; ++i) {
-                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
-                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
-                    }
-                }
             }
         }
     }
@@ -3671,7 +3656,7 @@ static struct ggml_cgraph * llm_build_starcoder(
     const float norm_eps = hparams.f_norm_eps;
 
     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
 
     auto & buf_compute = lctx.buf_compute;
@@ -3754,13 +3739,6 @@ static struct ggml_cgraph * llm_build_starcoder(
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                     }
                 }
-
-                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
-                for (int i = n_kv; i < n_ctx; ++i) {
-                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
-                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
-                    }
-                }
             }
         }
     }
@@ -4055,8 +4033,7 @@ static bool llama_eval_internal(
 #endif
 
     // update the kv ring buffer
-    lctx.kv_self.head     += n_tokens;
-    lctx.kv_self.cell_max = std::max(lctx.kv_self.cell_max, lctx.kv_self.head);
+    lctx.kv_self.head += n_tokens;
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
@@ -6834,6 +6811,10 @@ void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
     llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
 }
 
+void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    llama_kv_cache_shift(*ctx, seq_id, p0, p1, delta);
+}
+
 // Returns the *maximum* size of the state
 size_t llama_get_state_size(const struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
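
A hedged caller-side sketch of how the new entry point might be used for a context shift. It assumes a single sequence laid out contiguously from position 0 (so cell indices coincide with positions) and assumes a public llama_kv_cache_rm_tokens() wrapper mirroring the internal helper above; it only adjusts cell positions, and any re-roping of the cached K data is out of scope here:

// Hypothetical usage sketch, not part of this change: drop n_discard tokens
// after the first n_keep and slide the remainder back so positions stay
// contiguous for continued generation.
static void context_shift(struct llama_context * ctx, llama_pos n_keep, llama_pos n_past, llama_pos n_discard) {
    // free the cells holding positions [n_keep, n_keep + n_discard)
    llama_kv_cache_rm_tokens(ctx, n_keep, n_keep + n_discard);

    // slide the surviving positions [n_keep + n_discard, n_past) back by n_discard
    llama_kv_cache_shift(ctx, /*seq_id=*/0, n_keep + n_discard, n_past, -n_discard);
}
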
@@ -7130,8 +7111,6 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 
         ctx->kv_self.head = kv_ntok;
         ctx->kv_self.size = kv_size;
-
-        ctx->kv_self.cell_max = kv_ntok;
     }
 
     const size_t nread = inp - src;