
Commit a80d237

fix: [https://nvbugspro.nvidia.com/bug/5243482] If FlashMLA is used, the existence of FMHA based MLA kernels should not be checked. (#3862)
* Add mIsGenerationMLA to differentiate ctx and gen MLA in AttentionOp. For generation MLA, if FlashMLA is used, do not check the existence of the FMHA based MLA kernel.
  Signed-off-by: Bo Li <[email protected]>

* Run pre-commit.
  Signed-off-by: Bo Li <[email protected]>

* Fix compile error.
  Signed-off-by: Bo Li <[email protected]>

---------

Signed-off-by: Bo Li <[email protected]>
1 parent afb7d3a commit a80d237
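
The core of the change is a new per-op flag, mIsGenerationMLA, which records whether the attention op was built for the MLA generation (decode) phase rather than the context (prefill) phase. The following is a minimal sketch, not the actual TensorRT-LLM implementation: the member names mirror the diff below, but the helper needsFmhaMlaKernelCheck() is hypothetical and only illustrates the intended gating.

// Minimal sketch of the gating described in the commit message (hypothetical helper).
struct AttentionOpSketch
{
    bool mIsMLAEnabled = false;
    bool mIsGenerationMLA = false; // new in this commit: true for the MLA generation (decode) phase
    bool mUseGenFlashMLA = false;  // FlashMLA path (SM90, tokens_per_block == 64)

    // Should initialize() require an FMHA based MLA kernel to exist?
    bool needsFmhaMlaKernelCheck() const
    {
        if (!mIsMLAEnabled)
        {
            return false; // not an MLA op at all
        }
        if (!mIsGenerationMLA)
        {
            return true; // context MLA always goes through the FMHA dispatcher
        }
        // Generation MLA only needs the FMHA based kernel when FlashMLA is not used.
        return !mUseGenFlashMLA;
    }
};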

File tree

3 files changed: +24 −15 lines changed


cpp/tensorrt_llm/common/attentionOp.cpp (+12 −8)
@@ -2416,7 +2416,7 @@ int AttentionOp::initialize() noexcept
         fmhaParams.numTokensPerBlock = mTokensPerBlock;
         fmhaParams.headSize = mHeadSize;
         fmhaParams.headSizeV = mHeadSize;
-        if (mIsMLAEnabled)
+        if (mIsMLAEnabled && !mIsGenerationMLA)
         {
             // Context attention of MLA is different
             fmhaParams.numKvHeads = mNumHeads;
@@ -2476,10 +2476,9 @@ int AttentionOp::initialize() noexcept
             // Instantiate the mTllmGenFMHARunner used for MLA
             mTllmGenFMHARunner.reset(new TllmGenFmhaRunner(qDataType, kvDataType, outputDataType));
         }
-        else
+        else if (mIsGenerationMLA && !mUseGenFlashMLA)
         {
-            // Construct the fmha runner.
-            // FP8 Generation MLA also uses context FMHA.
+            // Construct the fmha runner for generation.
             if (mFP8GenerationMLA)
             {
                 data_type = DATA_TYPE_E4M3;
@@ -2524,13 +2523,17 @@ int AttentionOp::initialize() noexcept
                     "Deepseek should be supported by fmha in generation part.");
             }
         }
-
-        TLLM_CHECK_WITH_INFO(
-            mFmhaDispatcher->isSupported(), "Deepseek should be supported by fmha in context part.");
+        if (!mIsGenerationMLA)
+        {
+            TLLM_CHECK_WITH_INFO(
+                mFmhaDispatcher->isSupported(), "Deepseek should be supported by fmha in context part.");
+        }
     }
 
     // Fall back to unfused MHA kernels if not supported.
-    mEnableContextFMHA = mFmhaDispatcher->isSupported();
+    // Generation MLA reuses the context FMHA code path so set mEnableContextFMHA to true.
+    // However, do not check mFmhaDispatcher which is not used for generation MLA.
+    mEnableContextFMHA = mIsGenerationMLA || mFmhaDispatcher->isSupported();
 
     // Only FMHA supports custom mask currently.
     TLLM_CHECK_WITH_INFO(
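
The last change in this hunk is the one the bug report is about: mEnableContextFMHA no longer depends on mFmhaDispatcher for generation MLA. The snippet below is only an illustration of how the new expression behaves; the free function and assertions are not part of the codebase and simply restate the one-line change.

#include <cassert>

// Mirrors the new line: mEnableContextFMHA = mIsGenerationMLA || mFmhaDispatcher->isSupported();
static bool enableContextFmha(bool isGenerationMLA, bool dispatcherSupported)
{
    return isGenerationMLA || dispatcherSupported;
}

int main()
{
    // Context MLA still requires the FMHA dispatcher to report support.
    assert(enableContextFmha(/*isGenerationMLA=*/false, /*dispatcherSupported=*/true));
    assert(!enableContextFmha(false, false));

    // Generation MLA reuses the context FMHA code path regardless of the dispatcher,
    // so a missing FMHA based MLA kernel no longer disables it (e.g. when FlashMLA is used).
    assert(enableContextFmha(true, false));
    return 0;
}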
@@ -2697,6 +2700,7 @@ std::string AttentionOp::toString() const
     ss << "mFMHAForceFP32Acc: " << std::boolalpha << mFMHAForceFP32Acc << std::endl;
     ss << "mSM: " << mSM << std::endl;
     ss << "mUseTllmGen: " << mUseTllmGen << std::endl;
+    ss << "mIsGenerationMLA: " << std::boolalpha << mIsGenerationMLA << std::endl;
     ss << "mUseGenFlashMLA: " << mUseGenFlashMLA << std::endl;
     ss << "mMultiProcessorCount: " << mMultiProcessorCount << std::endl;
     ss << "mMaxSharedMemoryPerBlockOptin: " << mMaxSharedMemoryPerBlockOptin << std::endl;

cpp/tensorrt_llm/common/attentionOp.h (+5 −4)
@@ -379,6 +379,7 @@ class AttentionOp
     bool mSpecDecodingIsGenerationLengthVariable = false;
     int32_t mSpecDecodingMaxGenerationLength = 1;
     bool mIsMLAEnabled = false;
+    bool mIsGenerationMLA = false;
     bool mUseGenFlashMLA = false;
     tensorrt_llm::kernels::MlaMetaParams mMLAParams;
     int mCpSize = 1;
@@ -422,10 +423,10 @@ class AttentionOp
             mUnfuseQkvGemm, (int32_t) mType, mMaxContextLength, mQKVBiasEnabled, mCrossAttention, mMaxDistance,
             mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mDenseContextFMHA, mHasFullAttentionMask,
             mIsSpecDecodingEnabled, mUseSpecDecoding, mSpecDecodingIsGenerationLengthVariable,
-            mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank,
-            mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize,
-            mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA,
-            mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores);
+            mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(),
+            mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank,
+            mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode,
+            mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores);
     };
 
 private:

cpp/tensorrt_llm/thop/attentionOp.cpp (+7 −3)
@@ -452,14 +452,18 @@ torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch
     {
         TLLM_CHECK(host_kv_cache_pool_mapping.has_value());
         int32_t const layer_num = host_kv_cache_pool_mapping.value().size(0);
+
         op->mIsMLAEnabled = true;
-        op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
-        // only enable flash mla in the generation phase on sm90 and tokens_per_block == 64
-        op->mUseGenFlashMLA = tensorrt_llm::common::getSMVersion() == 90 && tokens_per_block == 64;
         op->mMLAParams = {static_cast<int>(q_lora_rank.value()), static_cast<int>(kv_lora_rank.value()),
             static_cast<int>(qk_nope_head_dim.value()), static_cast<int>(qk_rope_head_dim.value()),
             static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
             static_cast<int>(layer_num)};
+
+        op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
+        op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
+        // only enable flash mla on sm90 and head_size == 576 and tokens_per_block == 64
+        op->mUseGenFlashMLA = tensorrt_llm::common::getSMVersion() == 90 && tokens_per_block == 64;
+
         // The following two parameters are used to compute kvcache related parameters such as kvcache block_size. So
         // they need to be set to 1 and 512 + 64 for both context and generation. For MLA attention kernel configs,
         // mNumKVHeads/mHeadSize are overwritten in common/attentionOp.cpp.
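
The mIsGenerationMLA assignment relies on the head-size convention already spelled out in the surrounding comments: the generation-phase MLA op is created with head_size == kv_lora_rank + qk_rope_head_dim, which is 512 + 64 = 576 for the DeepSeek-style configuration referenced in the diff. Below is a worked example of that arithmetic using those illustrative values; it is not part of the patch.

#include <iostream>

int main()
{
    // Illustrative values taken from the comments in the diff (512 + 64).
    int const kv_lora_rank = 512;
    int const qk_rope_head_dim = 64;
    int const head_size = 576; // head size the generation MLA op is created with

    // Same test as: op->mIsGenerationMLA = head_size == kv_lora_rank + qk_rope_head_dim;
    bool const isGenerationMLA = head_size == kv_lora_rank + qk_rope_head_dim;

    std::cout << "kv_lora_rank + qk_rope_head_dim = " << kv_lora_rank + qk_rope_head_dim
              << ", mIsGenerationMLA = " << std::boolalpha << isGenerationMLA << std::endl;
    return 0;
}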
