13 changes: 7 additions & 6 deletions evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh
@@ -1,16 +1,16 @@
export VLLM_USE_V1=1
export VLLM_USE_TRITON_FLASH_ATTN=0
export VLLM_USE_TRITON_FLASH_ATTN=1 # use triton mha
# export VLLM_LOGGING_LEVEL=DEBUG
export VLLM_RPC_TIMEOUT=1800000
export VLLM_ROCM_USE_AITER=1
export VLLM_ROCM_USE_AITER_MHA=0
export VLLM_ROCM_USE_AITER_MLA=1
export VLLM_ROCM_USE_AITER_MLA=0 # use triton mha
export VLLM_ROCM_USE_AITER_MOE=1
export VLLM_ROCM_USE_TRITON_ROPE=1 # add for acc
export VLLM_DISABLE_COMPILE_CACHE=1
# FIXME: for now disable fp4 asm gemm because of running issue
export VLLM_ROCM_USE_AITER_FP4_ASM_GEMM=0
#export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0 # for now disable
export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0 # disable for acc

export TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE=1
export TRITON_HIP_USE_ASYNC_COPY=1
@@ -37,11 +37,12 @@ vllm serve $model_path \
--trust-remote-code \
--no-enable-prefix-caching \
--disable-log-requests \
--compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
--gpu_memory_utilization 0.8 \
--enforce-eager \
--gpu_memory_utilization 0.7 \
--async-scheduling \
--block-size 16 \
--load-format fastsafetensors \
--seed 123 2>&1 | tee log.server.log &

# --enforce-eager \
# --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
# --enable-expert-parallel \
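
Because the rendered diff above interleaves removed and added lines without +/- markers, the effective post-change launch configuration is easy to misread. The sketch below is a rough, hedged reconstruction as a small Python launcher: every environment variable and CLI flag is taken from the hunks above (reading each interleaved pair as removed-then-added); MODEL_PATH is a placeholder, and anything set earlier in the script but outside these hunks (for example the tensor-parallel settings implied by the file name) is omitted.

```python
# Hedged sketch: the post-change launch configuration from the hunks above,
# expressed as a small Python launcher. Every environment variable and CLI
# flag is copied from the diff; MODEL_PATH is a placeholder, and flags defined
# earlier in the script (outside these hunks) are omitted.
import os
import subprocess

os.environ.update({
    "VLLM_USE_V1": "1",
    "VLLM_USE_TRITON_FLASH_ATTN": "1",   # use triton mha
    "VLLM_RPC_TIMEOUT": "1800000",
    "VLLM_ROCM_USE_AITER": "1",
    "VLLM_ROCM_USE_AITER_MHA": "0",
    "VLLM_ROCM_USE_AITER_MLA": "0",      # use triton mha
    "VLLM_ROCM_USE_AITER_MOE": "1",
    "VLLM_ROCM_USE_TRITON_ROPE": "1",    # add for acc
    "VLLM_DISABLE_COMPILE_CACHE": "1",
    "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM": "0",            # fp4 asm gemm still disabled
    "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": "0",   # disable for acc
    "TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE": "1",
    "TRITON_HIP_USE_ASYNC_COPY": "1",
})

MODEL_PATH = "/path/to/deepseek-r1-fp4"  # placeholder, not from the PR

subprocess.run([
    "vllm", "serve", MODEL_PATH,
    "--trust-remote-code",
    "--no-enable-prefix-caching",
    "--disable-log-requests",
    "--enforce-eager",
    "--gpu_memory_utilization", "0.7",
    "--async-scheduling",
    "--block-size", "16",
    "--load-format", "fastsafetensors",
    "--seed", "123",
], check=True)
```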
17 changes: 12 additions & 5 deletions vllm/v1/attention/backends/mla/triton_mla.py
@@ -129,14 +129,21 @@ def _flash_attn_varlen_diff_headdims(
q, k, v, softmax_scale=softmax_scale, **kwargs
)
else:
return super()._flash_attn_varlen_diff_headdims(
q,
k,
v,
return_softmax_lse=return_softmax_lse,
from aiter.ops.triton.mha import flash_attn_varlen_func

result = flash_attn_varlen_func(
q=q,
k=k,
v=v,
return_lse=return_softmax_lse,
softmax_scale=softmax_scale,
**kwargs,
)
if type(result) is tuple and return_softmax_lse:
output, lse = result
lse = lse.T.contiguous()
return (output, lse)
return result

def _forward_decode(
self,
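
In the new else branch above, the varlen prefill call is routed to aiter's Triton `flash_attn_varlen_func`, and when a log-sum-exp is requested the returned LSE is transposed before being handed back. A minimal sketch of that wrapper pattern follows; `attn_fn` is a stand-in for the aiter kernel, and the layout comment reflects an assumption (that the kernel returns the LSE with its two axes swapped relative to what the MLA common path expects), not something stated in the diff.

```python
# Hedged sketch of the pattern in the new else-branch above. `attn_fn` stands
# in for aiter.ops.triton.mha.flash_attn_varlen_func; only the call shape used
# in the diff (q=, k=, v=, return_lse=, softmax_scale=, **kwargs) is assumed.
def attn_with_transposed_lse(attn_fn, q, k, v, *, softmax_scale,
                             return_softmax_lse=False, **kwargs):
    result = attn_fn(
        q=q,
        k=k,
        v=v,
        return_lse=return_softmax_lse,
        softmax_scale=softmax_scale,
        **kwargs,
    )
    if isinstance(result, tuple) and return_softmax_lse:
        output, lse = result
        # Assumption: the kernel emits LSE as [num_tokens, num_heads] while the
        # caller expects [num_heads, num_tokens]; transpose and make contiguous.
        lse = lse.T.contiguous()
        return output, lse
    return result
```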