[FP8][Kernel] Dynamic kv cache scaling factors computation #11906
Conversation
…ching (#317)
* Changed _k_scale and _v_scale to tensors
* Fixed ROCm paged attention with tensor kv scales
* Added on-the-fly scale factor calculation
* Fixed the AttentionMetadata issue; updated the description of the calculate-kv-scales flag in arg_utils.py
* Changed the K and V scale constants
* Removed an unneeded comment
* Changes to pass format.sh; also fixed lingering k_scale/v_scale : float usages
* Fix for TP > 1
* Ran format.sh
* Removed legacy kv_scale loading from the JSON file
* Removed the outdated kv cache docs
* Reverted some unwanted changes

Co-authored-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>

* Using tensors in the explicit cache function calls from the mllama implementation
* Properly creating the tensor

Signed-off-by: Gregory Shtrasberg <[email protected]>
This PR deprecates loading kv cache scales from JSON in favor of an option to compute them dynamically based on the first real input to the attention layer.
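For illustration, enabling the new behavior might look like the sketch below. The option name follows the calculate-kv-scales flag mentioned in the commit messages; the exact argument spelling, defaults, and the model chosen here are assumptions, not part of this PR's documented interface.

```python
# Minimal sketch: fp8 kv cache with dynamically computed scaling factors.
# Argument names follow the calculate-kv-scales flag named in the commits;
# they may differ from the final merged interface.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2-7B-Instruct",   # hypothetical example model
    kv_cache_dtype="fp8",             # store K/V in fp8e4m3
    calculate_kv_scales=True,         # derive k/v scales from the first real input
)
print(llm.generate("The capital of France is")[0].outputs[0].text)
```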
Our tests showed that the dynamic range computed from the first input to each layer is representative of the entire model, and that the resulting accuracy is comparable to scaling factors computed with the Quark quantizer (as in the HF amd/*-FP8-KV models).
Accuracy was measured with the P3L benchmark, which evaluates accuracy on decode steps and therefore exercises the data stored in the kv cache.
The K and V scale parameters are now on-device tensors so that their values can be changed after the graph has been captured. This also lays the foundation for per-channel quantization with tensor-shaped scales.
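As a minimal sketch of the underlying idea (not the PR's actual kernel or attention-layer code; all names below are illustrative), a per-tensor fp8e4m3 scale derived from the first real input could be computed like this:

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in torch.float8_e4m3fn

def compute_kv_scale(first_input: torch.Tensor) -> torch.Tensor:
    """Illustrative per-tensor scale mapping the observed dynamic range onto fp8e4m3."""
    amax = first_input.detach().abs().max().float()
    # Fall back to an identity scale if the first batch happens to be all zeros.
    return torch.where(amax > 0, amax / FP8_E4M3_MAX, torch.ones_like(amax))

# Usage sketch: quantize with the scale; dequantization multiplies back by it.
k = torch.randn(2, 8, 128)
k_scale = compute_kv_scale(k)  # 0-dim tensor, can live on the GPU and be updated in place
k_fp8 = (k / k_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
```

Keeping the scale as a 0-dim on-device tensor (rather than a Python float baked into the kernel launch) is what allows its value to be overwritten after graph capture, as described above.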
The effect is most visible on models whose dynamic value ranges fall outside the representable range of fp8e4m3, such as Qwen2 7B:
Using dynamic calculation reduces the PPL score from 34.84 to 22.62.
On Llama-based models the improvement is much smaller, because identity scales already work well there, but it can still be in the single-digit percent range, on par with using the scales from a quantized model.