
Why can Flash Attention only be used under the JAX backend? #2143

Open · pass-lin opened this issue Mar 15, 2025 · 3 comments
Labels: type:feature New feature or request

@pass-lin (Contributor)

def has_flash_attention_support():
    if (
        hasattr(keras.config, "is_flash_attention_enabled")
        and keras.config.backend() == "jax"
    ):
        try:
            from jax.nn import dot_product_attention as dot_product_attention
        except ImportError:
            logging.warning(
                "Flash attention is not supported in your current JAX version. "
                "Please update it by following the official guide: "
                "https://jax.readthedocs.io/en/latest/installation.html"
            )
            return False
        return True
    else:
        return False

This is from the Llama source code in KerasHub. Why isn't it set up so that Flash Attention can also be used with the Torch backend?
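
For reference, the is_flash_attention_enabled attribute checked above belongs to Keras 3's global flash-attention switches. A minimal usage sketch, assuming a recent Keras 3 release that provides these config functions:

import keras

# Globally opt in to flash attention; backend-specific checks like the one
# above then decide whether the fused kernel is actually available.
keras.config.enable_flash_attention()
print(keras.config.is_flash_attention_enabled())

# The switch can be turned off again at any time.
keras.config.disable_flash_attention()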

@divyashreepathihalli (Collaborator)

Hi @pass-lin, actually it would work with PyTorch too, if the hardware supports it, but you would need to disable this check.
FA is not supported in TensorFlow, and that support will not be added.
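
For illustration, a minimal sketch of the PyTorch-side path, assuming a CUDA GPU and half-precision inputs (both required by the flash kernels); torch.nn.functional.scaled_dot_product_attention uses the flash kernel when it is the only one allowed by the kernel context manager (newer releases prefer torch.nn.attention.sdpa_kernel for this):

import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, heads, seq_len, head_dim). Flash attention
# needs fp16/bf16 tensors on a supported CUDA GPU.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Allow only the flash kernel; this raises a RuntimeError if the
# hardware/dtype combination cannot use it, so it doubles as a capability test.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)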

@pass-lin (Contributor, Author)


I want to know if there's any way to enable Llama's Flash Attention under the Torch backend without modifying the KerasHub source code. If not, will it be supported in the future?
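
One possible stop-gap, sketched below with heavy caveats: override the capability check at runtime before the model is built. The module paths are assumptions about the keras-hub layout and may differ between versions, and because the attention module imports the helper by name, the patch has to be applied there as well:

# Assumed module paths; verify them against the installed keras-hub version.
import keras_hub.src.utils.keras_utils as keras_utils
import keras_hub.src.models.llama.llama_attention as llama_attention

# Bypass the JAX-only gate so the torch backend can take the fused path.
keras_utils.has_flash_attention_support = lambda: True
llama_attention.has_flash_attention_support = lambda: True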

@pass-lin (Contributor, Author) commented Mar 29, 2025

@divyashreepathihalli
Do you think it would be better to modify it to the following form?

def has_flash_attention_support():
    if hasattr(keras.config, "is_flash_attention_enabled"):
        if keras.config.backend() == "jax":
            try:
                from jax.nn import dot_product_attention as dot_product_attention
            except ImportError:
                logging.warning(
                    "Flash attention is not supported in your current JAX version. "
                    "Please update it by following the official guide: "
                    "https://jax.readthedocs.io/en/latest/installation.html"
                )
                return False
            return True
        elif keras.config.backend() == "torch":
            try:
                # Availability check only; the imported names are not used here.
                from torch.backends.cuda import SDPAParams
                from torch.backends.cuda import can_use_flash_attention
            except ImportError:
                logging.warning(
                    "Flash attention is not supported in your current PyTorch "
                    "version. Please update it by following the official guide: "
                    "https://pytorch.org/get-started/locally/"
                )
                return False
            return True
        # Other backends (e.g. TensorFlow) have no flash attention path.
        return False
    else:
        return False

If you think this is okay, I can submit a new PR.
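
A quick usage sketch of the proposed check under the torch backend; the backend has to be selected before keras is imported, and the expected result assumes a PyTorch build that exposes torch.backends.cuda.can_use_flash_attention:

import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import keras

keras.config.enable_flash_attention()
# With the revised helper above, this should report True when the torch
# import check succeeds, and False otherwise.
print(has_flash_attention_support())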
