30 changes: 26 additions & 4 deletions docker/Dockerfile.rocm
@@ -11,10 +11,32 @@ ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}}
# Install some basic utilities
RUN apt-get update -q -y && apt-get install -q -y \
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev \
apt-transport-https ca-certificates wget curl
apt-transport-https ca-certificates wget curl git

# Remove sccache
RUN python3 -m pip install --upgrade pip
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"

# BUILD FA only if both FLASH_ATTENTION_TRITON_AMD_ENABLE and GPU_ARCHS are passed/declared and non-empty
ARG FLASH_ATTENTION_TRITON_AMD_ENABLE
ARG GPU_ARCHS

RUN if [ -n "${FLASH_ATTENTION_TRITON_AMD_ENABLE}" ] && [ -n "${GPU_ARCHS}" ]; then \
echo "Compiling Flash Attention with GPU_ARCHS=${GPU_ARCHS}..." ; \
export FLASH_ATTENTION_TRITON_AMD_ENABLE="${FLASH_ATTENTION_TRITON_AMD_ENABLE}"; \
export GPU_ARCHS="${GPU_ARCHS}"; \
git clone --single-branch --branch main_perf https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& python3 setup.py install \
&& cd .. \
&& rm -rf flash-attention ; \
else \
echo "Skipping Flash Attention compilation (FLASH_ATTENTION_TRITON_AMD_ENABLE and/or GPU_ARCHS not set)." ; \
fi

ENV FLASH_ATTENTION_TRITON_AMD_ENABLE=${FLASH_ATTENTION_TRITON_AMD_ENABLE}
ENV GPU_ARCHS=${GPU_ARCHS}

ARG COMMON_WORKDIR
WORKDIR ${COMMON_WORKDIR}

@@ -27,9 +49,9 @@ FROM base AS fetch_vllm_1
ARG VLLM_REPO="https://github.com/vllm-project/vllm.git"
ARG VLLM_BRANCH="main"
ONBUILD RUN git clone ${VLLM_REPO} \
&& cd vllm \
&& git fetch -v --prune -- origin ${VLLM_BRANCH} \
&& git checkout FETCH_HEAD \
&& cd vllm \
&& git fetch -v --prune -- origin ${VLLM_BRANCH} \
&& git checkout FETCH_HEAD \
&& if [ ${VLLM_REPO} != "https://github.com/vllm-project/vllm.git" ] ; then \
git remote add upstream "https://github.com/vllm-project/vllm.git" \
&& git fetch upstream ; fi
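To exercise the new build-time gate, the image is built with both `FLASH_ATTENTION_TRITON_AMD_ENABLE` and `GPU_ARCHS` passed as build args; the conditional `RUN` then clones ROCm's `main_perf` flash-attention branch and installs it, and the two `ENV` lines carry the values into the runtime environment. The script below is a hypothetical smoke test, not part of this PR, for checking those effects from inside the built image.

```python
# smoke_test_fa.py -- hypothetical check, run inside a container built with
# --build-arg FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE --build-arg GPU_ARCHS=gfx1100
import os
from importlib.util import find_spec

# The ENV lines only carry values when the matching build args were passed.
assert os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE", (
    "FLASH_ATTENTION_TRITON_AMD_ENABLE was not baked into the image"
)
assert os.environ.get("GPU_ARCHS") == "gfx1100", (
    "GPU_ARCHS does not match the arch the new ViT attention gate expects"
)

# The conditional RUN installs ROCm's main_perf flash-attention branch.
assert find_spec("flash_attn") is not None, "flash_attn is not importable"

print("Triton flash-attention available for", os.environ["GPU_ARCHS"])
```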
17 changes: 15 additions & 2 deletions vllm/attention/layer.py
@@ -48,9 +48,10 @@
)

if current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx9
from vllm.platforms.rocm import on_gfx9, on_gfx1x
else:
on_gfx9 = lambda *args, **kwargs: False
on_gfx1x = lambda *args, **kwargs: False


FP8_DTYPE = current_platform.fp8_dtype()
@@ -103,13 +104,25 @@ def maybe_get_vit_flash_attn_backend(
use_upstream_fa: bool,
attn_backend_override: _Backend | None = None,
) -> tuple[_Backend, Callable | None]:
import os
from importlib.util import find_spec

if current_platform.is_rocm():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
attn_backend = _Backend.ROCM_AITER_FA

elif (
os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE"
and os.environ.get("GPU_ARCHS") == "gfx1100"
and find_spec("flash_attn") is not None
and on_gfx1x()
and attn_backend_override is None
):
attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True

elif (
check_upstream_fa_availability(torch.get_default_dtype())
and on_gfx9()
and attn_backend_override is None
):
attn_backend = _Backend.FLASH_ATTN
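Condensed, the ROCm branch of maybe_get_vit_flash_attn_backend now resolves in this order: AITER FA on gfx9, then the new env-gated upstream flash-attention path for gfx11xx, then the existing gfx9 upstream-FA check. The sketch below restates that precedence as a standalone function so it can be read and run in isolation; the boolean arguments stand in for vLLM's on_gfx9(), on_gfx1x(), the AITER env toggles and check_upstream_fa_availability(), and the final TORCH_SDPA fallback is assumed from the truncated tail of the hunk.

```python
# Standalone sketch of the ROCm ViT backend precedence after this change.
# The real logic lives in maybe_get_vit_flash_attn_backend; the predicates
# here are plain booleans so the snippet runs anywhere.
import os
from enum import Enum, auto
from importlib.util import find_spec


class Backend(Enum):  # stand-in for vllm's _Backend
    ROCM_AITER_FA = auto()
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


def pick_rocm_vit_backend(
    use_aiter: bool,
    use_aiter_mha: bool,
    on_gfx9: bool,
    on_gfx1x: bool,
    upstream_fa_ok: bool,
    backend_override=None,
) -> Backend:
    # 1. AITER MHA on gfx9 wins when both AITER toggles are on.
    if use_aiter and use_aiter_mha and on_gfx9:
        return Backend.ROCM_AITER_FA
    # 2. New in this PR: upstream flash_attn built with the Triton AMD backend
    #    on RDNA3 (gfx11xx), gated on the two env vars baked by the Dockerfile.
    #    The real code also flips use_upstream_fa to True on this path.
    if (
        os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE"
        and os.environ.get("GPU_ARCHS") == "gfx1100"
        and find_spec("flash_attn") is not None
        and on_gfx1x
        and backend_override is None
    ):
        return Backend.FLASH_ATTN
    # 3. Existing gfx9 upstream-FA fallback.
    if upstream_fa_ok and on_gfx9 and backend_override is None:
        return Backend.FLASH_ATTN
    # 4. Assumed default: SDPA.
    return Backend.TORCH_SDPA
```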
19 changes: 14 additions & 5 deletions vllm/attention/ops/vit_attn_wrappers.py
@@ -14,8 +14,15 @@

import einops
import torch
import vllm.envs as envs

from vllm.utils.torch_utils import direct_register_custom_op
from vllm.platforms import current_platform

if current_platform.is_rocm():
    from vllm.platforms.rocm import on_gfx9, on_gfx1x
else:
    on_gfx9 = lambda *args, **kwargs: False
    on_gfx1x = lambda *args, **kwargs: False


def xformers_attn_seqlens_wrapper(
@@ -61,10 +68,14 @@ def flash_attn_maxseqlen_wrapper(
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
if is_rocm_aiter:
Review comment (Contributor):

So I am not sure what the ask is on this PR, but from my end we should feel free to change these shims however we need.

These custom ops are purely to preserve traceability of the ViT component, and the signatures are designed like this because we went from selecting attention based on attrs of the model (self) to needing an independent function without a self parameter.

I do want to voice a design consideration on this backend selection logic as a whole, though: to me it would seem better if we could just pass attn_fn lambdas directly, as opposed to passing a backend enum and then doing the function selection later. I wonder what is preventing us from doing this in the code today? (traceability, etc.)

cc @ywang96 who may have more context on this

Review comment from @ywang96 (Member), Oct 30, 2025:

> it would seem better if we could just pass attn_fn lambdas directly as opposed to some backend enum and then doing the function selection later. I wonder what is preventing us from doing this in the code today?

From the user perspective it's cleaner to just pass in an enum (e.g., --mm-encoder-attn-backend TORCH_SDPA), and it's better for us to control this than to accept an entire free-form attn implementation. But I agree that the enum -> attn_fn resolution can happen at init time of XXXVisionTransformer, with the resolved attn_fn passed downstream to XXXVisionAttention. Does that align with what's on your mind?

Review comment (Collaborator):

@Lucaskabela @ywang96

I have an RFC, #27821, that proposes the same idea: the enum -> attn_fn resolution happens in XXXVisionTransformer.

However, taking into account that many of the VL models share the same logic as qwen2_5_vl.py, the RFC abstracts this further: the overriding logic should be handled by the platform, since only the platform knows which backends it can support.

So maybe_get_vit_flash_attn_backend would be solely responsible for the enum -> attn_fn mapping rather than also including overriding logic (it would be renamed to a new name matching its role).

if (
current_platform.is_rocm()
and on_gfx9()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_MHA
):
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
@@ -96,7 +107,6 @@ def flash_attn_maxseqlen_wrapper_fake(
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
b, s, h, d = q.shape
@@ -117,9 +127,8 @@ def vit_flash_attn_wrapper(
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
q, k, v, cu_seqlens, max_seqlen, batch_size, use_upstream_fa
)
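The wrapper changes above drop the is_rocm_aiter flag from the custom-op signature and instead decide inside the op, via on_gfx9() and the AITER env toggles, whether to import aiter's flash_attn_varlen_func. The review thread floats a different shape: resolve the backend enum to a concrete attention callable once, at vision-transformer init time, and pass that callable downstream. The sketch below only illustrates that idea with hypothetical names; it is not vLLM code, and it wires up just the SDPA case so it stays runnable anywhere.

```python
# Hypothetical illustration of the "resolve enum -> attn_fn at init time"
# idea from the review thread; none of these names are real vLLM APIs.
from enum import Enum, auto
from typing import Callable

import torch


class Backend(Enum):
    ROCM_AITER_FA = auto()
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


def _sdpa_attn(q, k, v):
    # (b, s, h, d) -> (b, h, s, d) for SDPA, then back.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2)


def resolve_vit_attn_fn(backend: Backend) -> Callable:
    """Resolved once in the vision transformer's __init__, so each layer's
    forward() only ever sees a plain callable (keeps the graph traceable)."""
    if backend is Backend.TORCH_SDPA:
        return _sdpa_attn
    raise NotImplementedError(f"sketch only wires up SDPA, got {backend}")


if __name__ == "__main__":
    attn_fn = resolve_vit_attn_fn(Backend.TORCH_SDPA)
    q = k = v = torch.randn(1, 16, 8, 64)  # (batch, seq, heads, head_dim)
    print(attn_fn(q, k, v).shape)  # torch.Size([1, 16, 8, 64])
```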
9 changes: 8 additions & 1 deletion vllm/model_executor/models/qwen2_5_vl.py
@@ -422,14 +422,21 @@ def forward(
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
from importlib.util import find_spec
Review comment (Contributor):

This entire logic seems unnecessary here because:

  1. max_seqlen is already a tensor (type hints FTW)
  2. use_upstream_fa is set upstream, in Qwen2_5_VisionTransformer's init; we should just modify the logic there if needed, as opposed to here


if not isinstance(max_seqlen, torch.Tensor):
max_seqlen = torch.tensor(
max_seqlen, device=q.device, dtype=torch.int32
)
self.use_upstream_fa = find_spec("flash_attn") is not None

context_layer = vit_flash_attn_wrapper(
q,
k,
v,
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == _Backend.ROCM_AITER_FA,
self.use_upstream_fa,
)
elif self.attn_backend == _Backend.TORCH_SDPA:
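Following the comment above (max_seqlen already arrives as a tensor, and use_upstream_fa is decided upstream), the flash_attn probe could be hoisted out of forward() entirely. The stub below is a hypothetical condensation, not the real Qwen2.5-VL attention class; it only shows where that decision would live.

```python
# Hypothetical stub showing the probe moved into __init__, per the review
# comment; this is not the actual Qwen2_5_VisionAttention implementation.
from importlib.util import find_spec

import torch
import torch.nn as nn


class VisionAttentionStub(nn.Module):
    def __init__(self, use_upstream_fa: bool = False):
        super().__init__()
        # Decide once at construction time; forward() never re-probes.
        self.use_upstream_fa = use_upstream_fa or find_spec("flash_attn") is not None

    def forward(self, q, k, v, cu_seqlens, max_seqlen: torch.Tensor):
        # max_seqlen is already a tensor per the type hints, so no isinstance()
        # conversion here; dispatch on the flag resolved in __init__.
        if self.use_upstream_fa:
            pass  # call the upstream flash_attn varlen path here
        return q  # placeholder return so the stub stays importable and runnable
```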
14 changes: 14 additions & 0 deletions vllm/platforms/rocm.py
@@ -77,6 +77,7 @@
"0x74b9": "AMD_Instinct_MI325X", # MI325X VF
"0x74a9": "AMD_Instinct_MI300X_HF",
"0x74bd": "AMD_Instinct_MI300X_HF",
"ox744c": "AMD_7900XTX_RDNA3"
}

# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
@@ -202,6 +203,7 @@ class RocmPlatform(Platform):

@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
import os
from importlib.util import find_spec

from vllm.attention.backends.registry import _Backend
@@ -212,6 +214,18 @@ def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
if on_gfx9() and find_spec("flash_attn") is not None:
return _Backend.FLASH_ATTN

if (
os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE"
and os.environ.get("GPU_ARCHS") == "gfx1100"
and find_spec("flash_attn") is not None
and on_gfx1x()
):
logger.info(
"Using ViT FlashAttention (upstream) on V1 engine (gfx1x / RDNA3)."
)
return _Backend.FLASH_ATTN

logger.info("Using Vit TORCH_SDPA V1 engine")
return _Backend.TORCH_SDPA

@classmethod
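With the platform hook above in place, the expectation on an RDNA3 (gfx11xx) machine running this branch is that get_vit_attn_backend returns FLASH_ATTN when both environment variables are set and flash_attn is importable, and TORCH_SDPA otherwise (assuming none of the earlier, truncated branches fire). A minimal manual check might look like the following; it assumes a ROCm build of vLLM on gfx11xx hardware, so it will not run elsewhere.

```python
# Hypothetical manual check on a gfx11xx box with this branch of vLLM installed.
import os

os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
os.environ["GPU_ARCHS"] = "gfx1100"

import torch  # noqa: E402
from vllm.platforms.rocm import RocmPlatform  # noqa: E402

backend = RocmPlatform.get_vit_attn_backend(head_size=64, dtype=torch.float16)
# Expected: _Backend.FLASH_ATTN if flash_attn is installed, else _Backend.TORCH_SDPA.
print(backend)
```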