- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 10.8k
[Bugfix] Ensure calculated KV scales are applied in attention. #27232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: adabeyta <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request correctly addresses a bug where calculated KV scales were not being applied during attention. The fix introduces a mechanism to calculate these scales on the first forward pass and then disables subsequent calculations by adding a state flag kv_scales_calculated to GPUModelRunner. The logic is sound and the changes are well-targeted. I've included one suggestion to refactor the implementation slightly, which will improve maintainability by removing a small piece of redundant logic.
        
          
                vllm/v1/worker/gpu_model_runner.py
              
                Outdated
          
        
      | if (self.cache_config.calculate_kv_scales | ||
| and not self.kv_scales_calculated): | ||
| self.kv_scales_calculated = True | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To improve maintainability and avoid re-evaluating the same condition, you can pass the enable_kv_scales_calculation flag from _prepare_inputs to execute_model. This makes the logic clearer and reduces redundancy.
Here's how you can do it:
- 
Update _prepare_inputsto returnenable_kv_scales_calculation:# In _prepare_inputs function signature ) -> tuple[ ..., bool, # use_cascade_attn bool, # enable_kv_scales_calculation ]: # In _prepare_inputs return statement return ( ..., use_cascade_attn, enable_kv_scales_calculation, ) 
- 
Update the call to _prepare_inputsinexecute_model:# In execute_model ( ..., use_cascade_attn, enable_kv_scales_calculation, ) = self._prepare_inputs(scheduler_output) 
- 
Then, you can simplify the logic for updating self.kv_scales_calculatedas suggested below.
# Mark KV scales as calculated if they were computed in this pass.
if enable_kv_scales_calculation:
    self.kv_scales_calculated = TrueThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] | ||
| if self.dcp_world_size > 1 | ||
| else None, | ||
| enable_kv_scales_calculation=enable_kv_scales_calculation, | ||
| ) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
  Propagate KV scale flag to per-layer metadata
The new enable_kv_scales_calculation flag is set only on CommonAttentionMetadata here, but the per-layer metadata objects that Attention.forward actually receives are produced later by builders (split_attn_metadata and the various *AttentionMetadataBuilder.build) without copying this attribute. As a result getattr(attn_metadata, "enable_kv_scales_calculation", False) in vllm/attention/layer.py remains False and the KV scale calculation path never runs, so the bug this change is meant to fix still occurs whenever KV scales are enabled. The flag needs to be attached to the per-layer metadata returned by the builders (and preserved when splitting) so layers can see it.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx for the fix, I got a suggestion to further simplify the logic
        
          
                vllm/v1/attention/backends/utils.py
              
                Outdated
          
        
      | dcp_local_seq_lens: torch.Tensor | None = None | ||
| """Sequence lengths of the local rank in decode context parallelism world""" | ||
|  | ||
| enable_kv_scales_calculation: bool = False | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you make sure this is removed from backend-specific attention metadata classes, if anywhere?
        
          
                vllm/v1/worker/gpu_model_runner.py
              
                Outdated
          
        
      | self.is_multimodal_pruning_enabled = False | ||
| self.max_model_len = model_config.max_model_len | ||
|  | ||
| self.kv_scales_calculated = False | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| self.kv_scales_calculated = False | |
| # Always set to false after the first forward pass | |
| self.calculate_kv_scales = self.cache_config.calculate_kv_scales | 
        
          
                vllm/v1/worker/gpu_model_runner.py
              
                Outdated
          
        
      | # Determine if we need to calculate KV scales on this forward pass. | ||
| # Only True on the first pass when calculate_kv_scales is enabled. | ||
| enable_kv_scales_calculation = ( | ||
| self.cache_config.calculate_kv_scales | ||
| and not self.kv_scales_calculated) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # Determine if we need to calculate KV scales on this forward pass. | |
| # Only True on the first pass when calculate_kv_scales is enabled. | |
| enable_kv_scales_calculation = ( | |
| self.cache_config.calculate_kv_scales | |
| and not self.kv_scales_calculated) | 
        
          
                vllm/v1/worker/gpu_model_runner.py
              
                Outdated
          
        
      | dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] | ||
| if self.dcp_world_size > 1 | ||
| else None, | ||
| enable_kv_scales_calculation=enable_kv_scales_calculation, | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| enable_kv_scales_calculation=enable_kv_scales_calculation, | |
| enable_kv_scales_calculation=self.calculate_kv_scales, | 
        
          
                vllm/v1/worker/gpu_model_runner.py
              
                Outdated
          
        
      | if (self.cache_config.calculate_kv_scales | ||
| and not self.kv_scales_calculated): | ||
| self.kv_scales_calculated = True | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even easier:
| if (self.cache_config.calculate_kv_scales | |
| and not self.kv_scales_calculated): | |
| self.kv_scales_calculated = True | |
| self.calculate_kv_scales = False | 
| Could you just sanity check with a deepseek model that it works with MLA too? | 
Purpose:
Resolves bug #27102 .
Test Plan:
Throughput
E2E Correctness
Server with KV scale calculation ON (remove --calculate-kv-scales flag for OFF case):
lm_eval
Throughput results
Main:calculate_kv_scales=False | enforce_eager=False
Main:calculate_kv_scales=True | enforce_eager=False
Main: calulate_kv_scales=True | enforce_eager=True
PR: calculate_kv_scales=False | enforce_eager=False
PR: calculate_kv_scales=True | enforce_eager=False
PR: calulate_kv_scales=True | enforce_eager=True
E2E GSM8K Accuracy Results
Main: calulate_kv_scales=False | enforce_eager=False
Main: calulate_kv_scales=True | enforce_eager=False
Main: calulate_kv_scales=True | enforce_eager=True
PR: calculate_kv_scales=False | enforce_eager=False
PR: calculate_kv_scales=True | enforce_eager=False
PR: calulate_kv_scales=True | enforce_eager=True