
Commit 46869d6

Hao Lu (hlu1) authored and committed
Fix fp8 kvcache
Signed-off-by: Hao Lu <[email protected]@users.noreply.github.com>
1 parent 8a994d8 commit 46869d6

7 files changed: 44 additions & 35 deletions

tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 3 additions & 5 deletions

@@ -422,11 +422,9 @@ def __init__(
     def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
         self.quant_config = new_quant_config
         self.has_fp8_kv_cache = False
-        if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant(
-        ):
-            quant_mode = self.quant_config.layer_quant_mode
-            if quant_mode.has_fp8_kv_cache():
-                self.has_fp8_kv_cache = True
+        if self.quant_config:
+            self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache(
+            )

     def forward(self,
                 q: torch.Tensor,
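
The rewritten check is a pure simplification: the old has_any_quant() gate already included the KV-cache bits (see the mode.py diff below), so the derived flag is unchanged. A minimal standalone sketch comparing the two forms; _ModeStub is an illustrative stand-in, not the TensorRT-LLM QuantMode:

# Stub standing in for layer_quant_mode, with just the two methods the old and
# new code paths call.
class _ModeStub:

    def __init__(self, fp8_kv_cache: bool, other_quant: bool):
        self.fp8_kv_cache = fp8_kv_cache
        self.other_quant = other_quant

    def has_fp8_kv_cache(self) -> bool:
        return self.fp8_kv_cache

    def has_any_quant(self) -> bool:
        # old gate: KV-cache-only quantization already counted as "any quant"
        return self.fp8_kv_cache or self.other_quant


def old_flag(mode: _ModeStub) -> bool:
    # pre-commit derivation: nested checks behind has_any_quant()
    has_fp8_kv_cache = False
    if mode.has_any_quant():
        if mode.has_fp8_kv_cache():
            has_fp8_kv_cache = True
    return has_fp8_kv_cache


def new_flag(mode: _ModeStub) -> bool:
    # post-commit derivation: one call on the quant mode
    return mode.has_fp8_kv_cache()


for kv in (False, True):
    for other in (False, True):
        assert old_flag(_ModeStub(kv, other)) == new_flag(_ModeStub(kv, other))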

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 5 additions & 14 deletions

@@ -80,7 +80,6 @@ def __init__(
             head_dim (int): The size of each attention head (hidden_size // num_heads).
             num_kv_heads (int): The number of kv heads. Defaults to num_heads if None.
             pos_embd_params (PositionalEmbeddingParams): Optional parameters defining how positional embedding should be applied.
-            quant_config (QuantConfig): Optional quantization configuration. If None, no quantization is applied.
         """
         rope_params = None
         if pos_embd_params is not None:
@@ -126,7 +125,7 @@ def __init__(
         self.kwargs = {}
         self.kwargs.update(kwargs)

-    def create_weights(self, quant_config: Optional[QuantConfig] = None):
+    def update_quant_config(self, quant_config: Optional[QuantConfig] = None):
         quant_config = quant_config or QuantConfig()
         self.quant_mode = int(quant_config.layer_quant_mode)

@@ -623,16 +622,17 @@ def __init__(

     def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
         self.quant_config = new_quant_config
-        self.wrapper.create_weights(self.quant_config)
+        self.wrapper.update_quant_config(self.quant_config)

         self.has_fp8_qdq = self.has_fp8_kv_cache = self.has_nvfp4 = False
         if self.quant_config is not None:
+            self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache(
+            )
+
             self.has_fp8_qdq = self.quant_config.layer_quant_mode.has_fp8_qdq()
             self.has_fp8_block_wise = self.quant_config.layer_quant_mode.has_fp8_block_scales(
             )
             self.has_nvfp4 = self.quant_config.layer_quant_mode.has_nvfp4()
-            self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache(
-            )
             self.has_nvfp4 = self.quant_config.layer_quant_mode.has_nvfp4()

     def forward(
@@ -662,15 +662,6 @@ def forward(
             or metadata.runtime_features.has_speculative_draft_tokens
         ) if metadata.runtime_features else False

-        if use_paged_context_fmha and self.has_fp8_kv_cache:
-            # NOTE: W4A8_AWQ can be included too, exclude for now since
-            # we don't use int4 in PyTorch
-            if not (self.has_fp8_qdq or self.has_nvfp4
-                    or self.has_fp8_block_wise):
-                raise RuntimeError(
-                    "When FP8 KV cache is being used, paged context FMHA cannot be used without "
-                    "FP8 attention.")
-
         num_seqs = metadata.num_seqs
         self.wrapper.plan(
             tokens_per_block=metadata.tokens_per_block,
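
A condensed sketch of the reshaped update_quant_config flow above: the wrapper hook is renamed from create_weights to update_quant_config, all capability flags are reset together, and has_fp8_kv_cache is derived directly from layer_quant_mode whenever a config is present. _WrapperStub and _BackendStub are illustrative stand-ins; QuantConfig is the library type and its import path here is an assumption:

from typing import Optional

from tensorrt_llm.models.modeling_utils import QuantConfig  # import path assumed


class _WrapperStub:
    """Illustrative stand-in for the attention wrapper, not a TensorRT-LLM class."""

    def update_quant_config(self, quant_config: Optional[QuantConfig] = None) -> None:
        # renamed from create_weights in this commit
        quant_config = quant_config or QuantConfig()
        self.quant_mode = int(quant_config.layer_quant_mode)


class _BackendStub:
    """Illustrative stand-in for the attention backend, not a TensorRT-LLM class."""

    def __init__(self) -> None:
        self.wrapper = _WrapperStub()

    def update_quant_config(self, new_quant_config: Optional[QuantConfig]) -> None:
        self.quant_config = new_quant_config
        self.wrapper.update_quant_config(self.quant_config)

        # reset, then derive every flag from the (optional) quant config
        self.has_fp8_qdq = self.has_fp8_kv_cache = self.has_nvfp4 = False
        if self.quant_config is not None:
            mode = self.quant_config.layer_quant_mode
            self.has_fp8_kv_cache = mode.has_fp8_kv_cache()
            self.has_fp8_qdq = mode.has_fp8_qdq()
            self.has_nvfp4 = mode.has_nvfp4()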

tensorrt_llm/_torch/model_config.py

Lines changed: 8 additions & 0 deletions

@@ -108,10 +108,18 @@ def from_pretrained(cls,
             mixed_quant_config_file = model_dir / 'quant_cfg.json'
             with open(mixed_quant_config_file) as fm:
                 mixed_quant_configs = json.load(fm)
+            # kv_cache_quant_algo is global regardless of MIXED_PRECISION
             kv_cache_quant_algo = mixed_quant_configs[
                 'kv_cache_quant_algo']
             mixed_quant_configs = mixed_quant_configs[
                 'quantized_layers']
+            if kv_cache_quant_algo is not None and quant_config.kv_cache_quant_algo is not None:
+                if kv_cache_quant_algo != quant_config.kv_cache_quant_algo:
+                    raise RuntimeError(
+                        f"The kvcache config in 'quant_cfg.json', {kv_cache_quant_algo},"
+                        f"is different from 'hf_quant_config.json', {quant_config.kv_cache_quant_algo}!"
+                    )
+            kv_cache_quant_algo = kv_cache_quant_algo or quant_config.kv_cache_quant_algo

             for layer in mixed_quant_configs:
                 config = QuantConfig()
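
The added block enforces one rule: the global kv_cache_quant_algo may come from either 'quant_cfg.json' or 'hf_quant_config.json', and when both provide it they must agree. A minimal sketch of that rule as a standalone helper (the function name is illustrative, not part of the library):

from typing import Optional


def reconcile_kv_cache_quant_algo(from_quant_cfg: Optional[str],
                                  from_hf_quant_config: Optional[str]) -> Optional[str]:
    """Return the single global KV-cache quant algo, raising on a mismatch."""
    if from_quant_cfg is not None and from_hf_quant_config is not None:
        if from_quant_cfg != from_hf_quant_config:
            raise RuntimeError(
                f"The kvcache config in 'quant_cfg.json', {from_quant_cfg}, "
                f"is different from 'hf_quant_config.json', {from_hf_quant_config}!")
    # fall back to whichever source actually specified it (possibly neither)
    return from_quant_cfg or from_hf_quant_config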

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 9 additions & 3 deletions

@@ -11,9 +11,9 @@
 from torch.utils._pytree import tree_any_only
 from tqdm import tqdm

-from tensorrt_llm.mapping import Mapping
-
 from ...logger import logger
+from ...mapping import Mapping
+from ...models.modeling_utils import QuantConfig
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig, TConfig
 from ..modules.attention import Attention
@@ -432,15 +432,21 @@ def __post_init__(self):
         # TODO: support MLA

         # 2. skip quant for modules in QuantConfig.exclude_modules
+        # kv_cache_quant_algo takes precedence over exclude_modules
         quant_config = self.model_config.quant_config
+        kv_cache_quant_algo = None
+        if quant_config:
+            kv_cache_quant_algo = quant_config.kv_cache_quant_algo
+        new_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)
+
         if quant_config is not None:
             if quant_config.exclude_modules is not None:
                 for name, module in self.named_modules():
                     is_excluded = quant_config.is_module_excluded_from_quantization(
                         name)
                     if is_excluded and getattr(module, "quant_config",
                                                None) is not None:
-                        module.quant_config = None
+                        module.quant_config = new_config

         for _, module in self.named_modules():
             if callable(getattr(module, "create_weights", None)):
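
Modules matched by exclude_modules now keep the global KV-cache algo instead of losing quantization entirely. A minimal sketch of the replacement config they receive, assuming QuantConfig is importable as in the diff above (the helper name is illustrative):

from typing import Optional

from tensorrt_llm.models.modeling_utils import QuantConfig  # import path assumed


def config_for_excluded_module(quant_config: Optional[QuantConfig]) -> QuantConfig:
    # Excluded modules drop weight/activation quantization (quant_algo stays unset)
    # but keep kv_cache_quant_algo, since the KV-cache format is global.
    kv_cache_quant_algo = None
    if quant_config:
        kv_cache_quant_algo = quant_config.kv_cache_quant_algo
    return QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)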

tensorrt_llm/_torch/modules/fused_moe.py

Lines changed: 4 additions & 2 deletions

@@ -410,7 +410,8 @@ def create_weights(self):
         self.has_fp8_qdq = False
         self.has_fp8_block_scales = False
         self.has_nvfp4 = False
-        if self.quant_config and self.quant_config.quant_mode.has_any_quant():
+        if self.quant_config and self.quant_config.quant_mode.has_any_quant(
+                exclude_kv_cache=True):
             self.has_any_quant = True
             qc = self.quant_config
             if qc.quant_mode.has_fp8_qdq():
@@ -1128,7 +1129,8 @@ def load_expert_w2_weight(w2_weight,
         load_expert_w2_weight(w2_weight, self.w2_weight.data[expert_idx],
                               is_trtllm_nvfp4)

-        if self.quant_config and self.quant_config.quant_mode.has_any_quant():
+        if self.quant_config and self.quant_config.quant_mode.has_any_quant(
+                exclude_kv_cache=True):
             if self.quant_config.quant_mode.has_fp8_qdq():
                 self._load_fp8_qdq_scales(weights)
             elif self.quant_config.quant_mode.has_nvfp4():

tensorrt_llm/_torch/modules/linear.py

Lines changed: 2 additions & 2 deletions

@@ -210,9 +210,9 @@ def create_weights(self):
         self.has_fp8_qdq = False
         self.has_fp8_block_scales = False
         self.has_nvfp4 = False
-        # only _create_weights, and load quantized weight directly.
+
         if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant(
-        ):
+                exclude_kv_cache=True):
             self.has_any_quant = True
             qc = self.quant_config
             if qc.layer_quant_mode.has_fp8_qdq():
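
Both linear.py and fused_moe.py now pass exclude_kv_cache=True, so a config that only quantizes the KV cache no longer activates the weight/activation quantization branches in create_weights. A hedged call-site illustration; the import paths and the assumption that layer_quant_mode reflects kv_cache_quant_algo are not shown in this diff:

from tensorrt_llm.models.modeling_utils import QuantConfig  # import path assumed
from tensorrt_llm.quantization.mode import QuantAlgo        # import path assumed

# FP8 KV cache only; no weight or activation quantization configured.
kv_only = QuantConfig(kv_cache_quant_algo=QuantAlgo.FP8)
mode = kv_only.layer_quant_mode

assert mode.has_fp8_kv_cache()                        # the cache is quantized...
assert not mode.has_any_quant(exclude_kv_cache=True)  # ...but Linear/FusedMoE stay unquantized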

tensorrt_llm/quantization/mode.py

Lines changed: 13 additions & 9 deletions

@@ -175,15 +175,19 @@ def has_nvfp4(self):
     def has_weight_quant(self):
         return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS)

-    def has_any_quant(self):
-        return self._any(self.INT4_WEIGHTS
-                         | self.INT8_WEIGHTS
-                         | self.ACTIVATIONS
-                         | self.INT8_KV_CACHE | self.FP8_KV_CACHE
-                         | self.NVFP4_KV_CACHE
-                         | self.FP8_QDQ | self.FP8_ROWWISE | self.W4A8_QSERVE
-                         | self.FP8_1x128_128x128
-                         | self.NVFP4)
+    def has_any_quant(self, exclude_kv_cache: bool = False):
+        has_quant = self._any(self.INT4_WEIGHTS
+                              | self.INT8_WEIGHTS
+                              | self.ACTIVATIONS
+                              | self.FP8_QDQ | self.FP8_ROWWISE
+                              | self.W4A8_QSERVE
+                              | self.FP8_1x128_128x128
+                              | self.NVFP4)
+        if exclude_kv_cache:
+            return has_quant
+
+        return has_quant | self._any(self.INT8_KV_CACHE | self.FP8_KV_CACHE
+                                     | self.NVFP4_KV_CACHE)

     def set_int8_kv_cache(self):
         return self | self.INT8_KV_CACHE
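
A short usage sketch for the new exclude_kv_cache switch, using only methods visible in this diff plus set_fp8_kv_cache, which is assumed to exist alongside set_int8_kv_cache (QuantMode is an IntFlag, so QuantMode(0) is the empty mode):

from tensorrt_llm.quantization.mode import QuantMode

mode = QuantMode(0).set_fp8_kv_cache()  # only the FP8 KV-cache bit set (assumed setter)
assert mode.has_fp8_kv_cache()
assert mode.has_any_quant()                           # KV-cache bits still count by default
assert not mode.has_any_quant(exclude_kv_cache=True)  # and can now be excluded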
