
Conversation

@adabeyta (Contributor) commented on Oct 21, 2025

Purpose:

Resolves bug #27102.

Test Plan:

Throughput

vllm bench throughput \
    --model Qwen/Qwen3-8B \
    --quantization fp8 \
    --kv-cache-dtype fp8_e4m3 \
    --tensor-parallel-size 2 \
    --dataset-name random \
    --input-len 1024 \
    --output-len 256

E2E Correctness
Server with KV scale calculation ON (remove the --calculate-kv-scales flag for the OFF case):

vllm serve Qwen/Qwen3-8B \
    --tensor-parallel-size 2 \
    --quantization fp8 \
    --kv-cache-dtype fp8_e4m3 \
    --calculate-kv-scales \
    --port 8000
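
Optionally, before running lm_eval, a quick sanity check that the server answers completion requests. This is a hedged example using the openai Python client against the server started above; the port and api_key value are assumptions based on the command shown.

```python
# Optional sanity check that the vLLM OpenAI-compatible server is up.
# Assumes the serve command above is listening on port 8000.
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="EMPTY")
resp = client.completions.create(
    model="Qwen/Qwen3-8B",
    prompt="2 + 2 =",
    max_tokens=8,
)
print(resp.choices[0].text)
```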

lm_eval

lm_eval --model local-completions \
    --model_args pretrained=<model>,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3 \
    --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 100

Throughput results

Main: calculate_kv_scales=False | enforce_eager=False

Throughput: 1.73 requests/s, 2210.02 total tokens/s, 442.00 output tokens/s
Total num prompt tokens:  1024000
Total num output tokens:  256000


Main: calculate_kv_scales=True | enforce_eager=False

Throughput: 1.73 requests/s, 2211.36 total tokens/s, 442.27 output tokens/s
Total num prompt tokens:  1024000
Total num output tokens:  256000

Main: calculate_kv_scales=True | enforce_eager=True

Throughput: 1.72 requests/s, 2202.42 total tokens/s, 440.48 output tokens/s
Total num prompt tokens:  1024000
Total num output tokens:  256000

PR: calculate_kv_scales=False | enforce_eager=False

Throughput: 1.73 requests/s, 2211.37 total tokens/s, 442.27 output tokens/s
Total num prompt tokens:  1024000
Total num output tokens:  256000

PR: calculate_kv_scales=True | enforce_eager=False

Throughput: 1.73 requests/s, 2210.78 total tokens/s, 442.16 output tokens/s
Total num prompt tokens:  1024000
Total num output tokens:  256000

PR: calculate_kv_scales=True | enforce_eager=True

Throughput: 1.72 requests/s, 2200.57 total tokens/s, 440.11 output tokens/s
Total num prompt tokens:  1024000
Total num output tokens:  256000

E2E GSM8K Accuracy Results

Main: calculate_kv_scales=False | enforce_eager=False

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.84|±  |0.0368|
|     |       |strict-match    |     5|exact_match|↑  | 0.83|±  |0.0378|

Main: calculate_kv_scales=True | enforce_eager=False

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.85|±  |0.0359|
|     |       |strict-match    |     5|exact_match|↑  | 0.86|±  |0.0349|

Main: calculate_kv_scales=True | enforce_eager=True

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.88|±  |0.0327|
|     |       |strict-match    |     5|exact_match|↑  | 0.89|±  |0.0314|

PR: calculate_kv_scales=False | enforce_eager=False

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.86|±  |0.0349|
|     |       |strict-match    |     5|exact_match|↑  | 0.86|±  |0.0349|

PR: calculate_kv_scales=True | enforce_eager=False

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.86|±  |0.0349|
|     |       |strict-match    |     5|exact_match|↑  | 0.86|±  |0.0349|

PR: calculate_kv_scales=True | enforce_eager=True

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.89|±  |0.0314|
|     |       |strict-match    |     5|exact_match|↑  | 0.90|±  |0.0302|

@mergify bot added the v1 label on Oct 21, 2025
@gemini-code-assist bot left a comment

Code Review

This pull request correctly addresses a bug where calculated KV scales were not being applied during attention. The fix introduces a mechanism to calculate these scales on the first forward pass and then disables subsequent calculations by adding a state flag kv_scales_calculated to GPUModelRunner. The logic is sound and the changes are well-targeted. I've included one suggestion to refactor the implementation slightly, which will improve maintainability by removing a small piece of redundant logic.
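
For reference, the mechanism described above boils down to a one-shot flag. The following is a minimal sketch only, not the actual GPUModelRunner code; `cache_config` and the metadata plumbing are stand-ins for the real objects.

```python
# Simplified illustration of the one-shot KV-scale flag; not the actual
# GPUModelRunner code. `cache_config` is a stand-in for the real config.
class GPUModelRunnerSketch:
    def __init__(self, cache_config):
        self.cache_config = cache_config
        # Tracks whether KV scales have already been computed once.
        self.kv_scales_calculated = False

    def _prepare_inputs(self) -> bool:
        # Request KV-scale calculation only on the very first forward pass.
        return (
            self.cache_config.calculate_kv_scales
            and not self.kv_scales_calculated
        )

    def execute_model(self):
        enable_kv_scales_calculation = self._prepare_inputs()
        # ... run the forward pass, passing the flag via attention metadata ...
        if enable_kv_scales_calculation:
            # Later passes reuse the scales computed on this pass.
            self.kv_scales_calculated = True
```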

Comment on lines 2538 to 2540
        if (self.cache_config.calculate_kv_scales
                and not self.kv_scales_calculated):
            self.kv_scales_calculated = True
Severity: high

To improve maintainability and avoid re-evaluating the same condition, you can pass the enable_kv_scales_calculation flag from _prepare_inputs to execute_model. This makes the logic clearer and reduces redundancy.

Here's how you can do it:

  1. Update _prepare_inputs to return enable_kv_scales_calculation:

    # In _prepare_inputs function signature
    ) -> tuple[
        ...,
        bool,  # use_cascade_attn
        bool,  # enable_kv_scales_calculation
    ]:
    
    # In _prepare_inputs return statement
    return (
        ...,
        use_cascade_attn,
        enable_kv_scales_calculation,
    )
  2. Update the call to _prepare_inputs in execute_model:

    # In execute_model
    (
        ...,
        use_cascade_attn,
        enable_kv_scales_calculation,
    ) = self._prepare_inputs(scheduler_output)
  3. Then, you can simplify the logic for updating self.kv_scales_calculated as suggested below.

# Mark KV scales as calculated if they were computed in this pass.
if enable_kv_scales_calculation:
    self.kv_scales_calculated = True

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Comment on lines 1355 to 1359
            dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
            if self.dcp_world_size > 1
            else None,
            enable_kv_scales_calculation=enable_kv_scales_calculation,
        )

P1: Propagate KV scale flag to per-layer metadata

The new enable_kv_scales_calculation flag is set only on CommonAttentionMetadata here, but the per-layer metadata objects that Attention.forward actually receives are produced later by builders (split_attn_metadata and the various *AttentionMetadataBuilder.build) without copying this attribute. As a result getattr(attn_metadata, "enable_kv_scales_calculation", False) in vllm/attention/layer.py remains False and the KV scale calculation path never runs, so the bug this change is meant to fix still occurs whenever KV scales are enabled. The flag needs to be attached to the per-layer metadata returned by the builders (and preserved when splitting) so layers can see it.
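
If that route is taken, the change would look roughly like the sketch below. The class and field names are illustrative placeholders rather than the exact vLLM builder definitions; the point is only that the builder copies the flag from the common metadata onto whatever per-layer metadata it returns.

```python
# Illustrative sketch: a builder forwarding the flag from the common
# metadata onto the per-layer metadata it constructs. Names are
# placeholders, not the exact vLLM classes.
from dataclasses import dataclass


@dataclass
class PerLayerAttentionMetadata:
    num_actual_tokens: int
    # Must be carried over so Attention.forward's getattr(...) sees it.
    enable_kv_scales_calculation: bool = False


class ExampleAttentionMetadataBuilder:
    def build(self, common_attn_metadata) -> PerLayerAttentionMetadata:
        return PerLayerAttentionMetadata(
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            enable_kv_scales_calculation=(
                common_attn_metadata.enable_kv_scales_calculation
            ),
        )
```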


@ProExpertProg (Collaborator) left a comment

Thanks for the fix! I have a suggestion to further simplify the logic.

    dcp_local_seq_lens: torch.Tensor | None = None
    """Sequence lengths of the local rank in decode context parallelism world"""

    enable_kv_scales_calculation: bool = False

Can you make sure this is removed from backend-specific attention metadata classes, if anywhere?

        self.is_multimodal_pruning_enabled = False
        self.max_model_len = model_config.max_model_len

        self.kv_scales_calculated = False

Suggested change:

-        self.kv_scales_calculated = False
+        # Always set to false after the first forward pass
+        self.calculate_kv_scales = self.cache_config.calculate_kv_scales

Comment on lines 1333 to 1337
        # Determine if we need to calculate KV scales on this forward pass.
        # Only True on the first pass when calculate_kv_scales is enabled.
        enable_kv_scales_calculation = (
            self.cache_config.calculate_kv_scales
            and not self.kv_scales_calculated)

Suggested change (remove this block):

-        # Determine if we need to calculate KV scales on this forward pass.
-        # Only True on the first pass when calculate_kv_scales is enabled.
-        enable_kv_scales_calculation = (
-            self.cache_config.calculate_kv_scales
-            and not self.kv_scales_calculated)

            dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
            if self.dcp_world_size > 1
            else None,
            enable_kv_scales_calculation=enable_kv_scales_calculation,

Suggested change:

-            enable_kv_scales_calculation=enable_kv_scales_calculation,
+            enable_kv_scales_calculation=self.calculate_kv_scales,

Comment on lines 2538 to 2540
        if (self.cache_config.calculate_kv_scales
                and not self.kv_scales_calculated):
            self.kv_scales_calculated = True

Even easier:

Suggested change:

-        if (self.cache_config.calculate_kv_scales
-                and not self.kv_scales_calculated):
-            self.kv_scales_calculated = True
+        self.calculate_kv_scales = False
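
Taken together, the suggestions above would simplify the flow to something like the sketch below. This is hedged: the naming follows the suggestions in this review, not the final merged code.

```python
# Sketch of the simplified flow proposed above; not the exact vLLM code.
class SimplifiedRunnerSketch:
    def __init__(self, cache_config):
        # True only if the user asked for KV-scale calculation;
        # flipped to False after the first forward pass.
        self.calculate_kv_scales = cache_config.calculate_kv_scales

    def build_common_metadata(self, **fields):
        # Pass the flag straight through; no local recomputation needed.
        return dict(
            enable_kv_scales_calculation=self.calculate_kv_scales, **fields
        )

    def execute_model(self):
        metadata = self.build_common_metadata()
        # ... run the forward pass using `metadata` ...
        # Disable further KV-scale calculation after the first pass.
        self.calculate_kv_scales = False
```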

@ProExpertProg (Collaborator) commented:

Could you just sanity check with a deepseek model that it works with MLA too?
