Conversation

@tjtanaa (Collaborator) commented Nov 1, 2025

Purpose

This PR implements RFC #27821.

1. Decouple the ViT attention backend selection logic from the text attention backend

_MHA_Backend is redefined following the discussion in https://github.com/vllm-project/vllm/pull/27061/files#r2443909604:

import enum


class _MHA_Backend(enum.Enum):
    VLLM_FLASH_ATTN = enum.auto()  # CUDA-only
    FLASH_ATTN = enum.auto()  # CUDA/ROCm
    XFORMERS = enum.auto()  # CUDA
    ROCM_AITER_FA = enum.auto()  # ROCm-only
    TORCH_SDPA = enum.auto()  # CUDA/ROCm/TPU/XPU/CPU
    PALLAS = enum.auto()  # TPU-only
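
As an illustration only (the actual flag parsing lives in vLLM's CLI/config layer and may work differently), a --mm-encoder-attn-backend value could be resolved to an enum member by name:

# Hypothetical name-based lookup; error handling and the real flag plumbing are omitted.
requested = "FLASH_ATTN"  # e.g. the value passed to --mm-encoder-attn-backend
mm_encoder_attn_backend = _MHA_Backend[requested] if requested is not None else None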

2. Make ViT backend selection platform-specific

ViT attention backend selection should be platform-specific: the supported ViT backends are determined through the platform interface. Any overriding logic should live in the platform interface, since only the platform knows which other ViT backends it can fall back to. We should avoid adding overriding logic to the model.py files.

Therefore, get_vit_attn_backend in the platform interface must be able to access the --mm-encoder-attn-backend value.

The platform interface should only return an _MHA_Backend; it should not return functions. The flash-attention functions should only be returned through maybe_get_vit_flash_attn_backend. If a model only supports a specific type of attention, the ViT overriding logic is implemented explicitly in the model definition file (model.py).

from typing import Optional

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class Platform:

    ...

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["_MHA_Backend"]:
        from vllm.attention.backends.registry import _MHA_Backend

        return [
            _MHA_Backend.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["_MHA_Backend"] = None,
    ) -> "_MHA_Backend":
        # ViT attention should be checked and overridden
        # in the platform-specific implementation.
        # We should not override this anywhere else,
        # e.g. in model_executor/models/<model_name>.py.

        # So the steps are:
        # 1. Check whether the backend is None:
        #    a. If not None, check that the backend is supported by the platform.
        #    b. If None, continue to the default selection logic.

        # Import _MHA_Backend here to avoid a circular import.
        from vllm.attention.backends.registry import _MHA_Backend

        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for ViT attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
            )
            logger.info_once(f"Using backend {backend} for ViT attention")
            return backend

        logger.info_once(
            f"Using default backend {_MHA_Backend.TORCH_SDPA} for ViT attention"
        )
        return _MHA_Backend.TORCH_SDPA
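
For context, the sketch below shows how a ViT attention module might consume this interface. It is illustrative only: the VisionAttention class, the mm_encoder_attn_backend argument, and the import path and single-return form of maybe_get_vit_flash_attn_backend are assumptions for this sketch (the single-return form matches the follow-up cleanup suggested later in this thread), not the exact vLLM API.

import torch

from vllm.attention.layer import maybe_get_vit_flash_attn_backend  # import path assumed
from vllm.platforms import current_platform


class VisionAttention(torch.nn.Module):  # illustrative module, not a real vLLM class
    def __init__(self, head_size: int, dtype: torch.dtype, mm_encoder_attn_backend=None):
        super().__init__()
        # The platform decides (and may override) the ViT backend; the model
        # only forwards the user's --mm-encoder-attn-backend choice.
        self.attn_backend = current_platform.get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            backend=mm_encoder_attn_backend,
        )
        # Flash-attention kernels, if available, are fetched separately; the
        # platform interface itself never returns functions.
        self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
            self.attn_backend,
        )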

Test Plan

Add unit tests covering:

  • The default ViT selection logic (backend=None)
  • Every platform-supported ViT backend, i.e. current_platform.get_supported_vit_attn_backends() (a sketch of such a test follows this list)
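
A minimal sketch of what these selection-logic tests could look like, assuming the platform interface shown above (the real tests may be organized differently):

import pytest
import torch

from vllm.platforms import current_platform


def test_default_vit_backend_selection():
    # backend=None should fall through to the platform's default choice.
    backend = current_platform.get_vit_attn_backend(
        head_size=64, dtype=torch.float16, backend=None
    )
    assert backend in current_platform.get_supported_vit_attn_backends()


@pytest.mark.parametrize(
    "backend", current_platform.get_supported_vit_attn_backends()
)
def test_supported_vit_backend_is_honored(backend):
    # Explicitly requesting any supported backend should return it unchanged.
    assert (
        current_platform.get_vit_attn_backend(
            head_size=64, dtype=torch.float16, backend=backend
        )
        == backend
    )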

Set up unit tests for all affected models:

  • tests/models/multimodal/generation/test_ovis2_5.py
  • tests/models/multimodal/generation/test_maverick.py
  • tests/models/multimodal/generation/test_dots_ocr.py
  • tests/models/multimodal/generation/test_keye.py
  • tests/models/multimodal/generation/test_qwen2_5_vl.py
  • tests/models/multimodal/generation/test_qwen2_vl.py
  • tests/models/multimodal/generation/test_qwen3_omni_moe_thinker.py
  • tests/models/multimodal/generation/test_ernie45_vl.py
  • tests/models/multimodal/generation/test_glm4_1v.py

Test Result

All unit tests passed.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: tjtanaa <[email protected]>
@mergify (bot) added the multi-modality, qwen, rocm, and tpu labels on Nov 1, 2025

@gemini-code-assist (bot) left a comment

Code Review

This pull request is a great refactoring that decouples the ViT attention backend selection from the text attention backend and moves platform-specific logic to the platform interface. This significantly improves code organization and maintainability. I've identified a couple of critical issues that need to be addressed to ensure the new logic works as intended.

@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


@tjtanaa (Collaborator, Author) commented Nov 1, 2025

@codex review

@chatgpt-codex-connector (bot) commented

Codex Review: Didn't find any major issues. 👍


@jikunshang (Collaborator) commented

Thanks for the refactor! cc @yma11 @faaany, please review and test on XPU.

@ywang96 (Member) left a comment

Thanks for working on this amazing effort! I will review this PR tonight!

Signed-off-by: tjtanaa <[email protected]>
@tjtanaa (Collaborator, Author) commented Nov 4, 2025

@ywang96 @DarkLight1337 I have rebased onto upstream again. This PR now also contains the torch.compile enablement for the Qwen2.5-VL model.

I suggest that further enabling torch.compile for all other models, and cleaning up maybe_get_vit_flash_attn_backend, e.g.

self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
    self.attn_backend,
)

be done in another PR after this one.

@jikunshang (Collaborator) left a comment

LGTM!

@mergify (bot) commented Nov 11, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify (bot) added the needs-rebase label on Nov 11, 2025

Labels

multi-modality, needs-rebase, nvidia, qwen, rocm, tpu
