@@ -112,9 +112,17 @@ llama_context::llama_context(
112112 }
113113 }
114114
115+ cparams.n_ctx_seq = cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max ;
116+
117+ if (cparams.n_ctx_seq > hparams.n_ctx_train ) {
118+ LLAMA_LOG_WARN (" %s: capping n_ctx_seq (%u) to n_ctx_train (%u)\n " , __func__, cparams.n_ctx_seq , hparams.n_ctx_train );
119+
120+ cparams.n_ctx_seq = hparams.n_ctx_train ;
121+ }
122+
115123 LLAMA_LOG_INFO (" %s: n_seq_max = %u\n " , __func__, cparams.n_seq_max );
116124 LLAMA_LOG_INFO (" %s: n_ctx = %u\n " , __func__, cparams.n_ctx );
117- LLAMA_LOG_INFO (" %s: n_ctx_per_seq = %u\n " , __func__, n_ctx_per_seq () );
125+ LLAMA_LOG_INFO (" %s: n_ctx_seq = %u\n " , __func__, cparams. n_ctx_seq );
118126 LLAMA_LOG_INFO (" %s: n_batch = %u\n " , __func__, cparams.n_batch );
119127 LLAMA_LOG_INFO (" %s: n_ubatch = %u\n " , __func__, cparams.n_ubatch );
120128 LLAMA_LOG_INFO (" %s: causal_attn = %d\n " , __func__, cparams.causal_attn );
@@ -123,14 +131,14 @@ llama_context::llama_context(
123131 LLAMA_LOG_INFO (" %s: freq_base = %.1f\n " , __func__, cparams.rope_freq_base );
124132 LLAMA_LOG_INFO (" %s: freq_scale = %g\n " , __func__, cparams.rope_freq_scale );
125133
126- if (n_ctx_per_seq () < hparams.n_ctx_train ) {
127- LLAMA_LOG_WARN (" %s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n " ,
128- __func__, n_ctx_per_seq () , hparams.n_ctx_train );
134+ if (cparams. n_ctx_seq < hparams.n_ctx_train ) {
135+ LLAMA_LOG_WARN (" %s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n " ,
136+ __func__, cparams. n_ctx_seq , hparams.n_ctx_train );
129137 }
130138
131- if (n_ctx_per_seq () > hparams.n_ctx_train ) {
132- LLAMA_LOG_WARN (" %s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n " ,
133- __func__, n_ctx_per_seq () , hparams.n_ctx_train );
139+ if (cparams. n_ctx_seq > hparams.n_ctx_train ) {
140+ LLAMA_LOG_WARN (" %s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n " ,
141+ __func__, cparams. n_ctx_seq , hparams.n_ctx_train );
134142 }
135143
136144 if (!hparams.vocab_only ) {
@@ -451,8 +459,8 @@ uint32_t llama_context::n_ctx() const {
451459 return cparams.n_ctx ;
452460}
453461
// Accessor for the per-sequence context size. This hunk renames
// n_ctx_per_seq() to n_ctx_seq() and changes where the value comes from:
//   - removed lines: computed on every call as n_ctx when kv_unified is set,
//     otherwise n_ctx / n_seq_max;
//   - added lines: return the cparams.n_ctx_seq field, which the constructor
//     hunk of this same diff assigns from that formula and then caps to
//     hparams.n_ctx_train.
// NOTE(review): because of that cap the new accessor can return a smaller
// value than the old computation did for the same parameters -- confirm that
// callers expect this.
454- uint32_t llama_context::n_ctx_per_seq () const {
455- return cparams.kv_unified ? cparams. n_ctx : cparams. n_ctx / cparams. n_seq_max ;
462+ uint32_t llama_context::n_ctx_seq () const {
463+ return cparams.n_ctx_seq ;
456464}
457465
458466uint32_t llama_context::n_batch () const {
@@ -2381,6 +2389,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
23812389 return ctx->n_ctx ();
23822390}
23832391
// Free-function wrapper added by this diff: returns ctx->n_ctx_seq().
// ctx is dereferenced unconditionally (no null check), the same way the
// neighbouring llama_n_ctx() wrapper calls ctx->n_ctx().
2392+ uint32_t llama_n_ctx_seq (const llama_context * ctx) {
2393+ return ctx->n_ctx_seq ();
2394+ }
2395+
// Free-function wrapper: returns ctx->n_batch().
// ctx is dereferenced unconditionally (no null check).
23842396uint32_t llama_n_batch (const llama_context * ctx) {
23852397 return ctx->n_batch ();
23862398}
0 commit comments