Why can Flash Attention only be used under the JAX backend? #2143
Comments
Hi @pass-lin, actually it would work with PyTorch too, if the hardware supports it; you would just need to disable this check.
I want to know if there is any way to enable Llama's Flash Attention under the Torch backend without modifying the Keras Hub source code. If not, will it be supported in the future?
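For reference, one way to "disable this check" without editing the installed package is a runtime monkey-patch. This is only a sketch: the module path `keras_hub.src.utils.keras_utils` and the symbol name are assumptions based on the snippet quoted below, and bypassing the check only helps if your PyTorch build and GPU actually support the fused scaled-dot-product flash kernel.

```python
# Hypothetical workaround (sketch, not an official API): make the Keras Hub
# capability check report flash-attention support under the torch backend.
# The import path below is an assumption; adjust it to your installed version.
import keras
from keras_hub.src.utils import keras_utils

if keras.config.backend() == "torch":
    keras_utils.has_flash_attention_support = lambda: True
    # Note: modules that already did
    # "from keras_utils import has_flash_attention_support" hold their own
    # reference, so this patch must run before those modules are imported.
```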
@divyashreepathihalli

```python
import logging

import keras


def has_flash_attention_support():
    if hasattr(keras.config, "is_flash_attention_enabled"):
        if keras.config.backend() == "jax":
            try:
                # Only check that the fused op is importable; it is not used here.
                from jax.nn import dot_product_attention  # noqa: F401
            except ImportError:
                logging.warning(
                    "Flash attention is not supported in your current JAX version. "
                    "Please update it by following the official guide: "
                    "https://jax.readthedocs.io/en/latest/installation.html"
                )
                return False
            return True
        elif keras.config.backend() == "torch":
            try:
                from torch.backends.cuda import SDPAParams  # noqa: F401
                from torch.backends.cuda import can_use_flash_attention  # noqa: F401
            except ImportError:
                logging.warning(
                    "Flash attention is not supported in your current PyTorch "
                    "version. Please update it by following the official guide: "
                    "https://pytorch.org/get-started/locally/"
                )
                return False
            return True
        else:
            return False
    else:
        # Older Keras versions without the flash-attention config flag.
        return False
```

If you think this is okay, I can submit a new PR.
This is the source code used by Llama. I want to know why we don't allow it to use Flash Attention under the Torch backend?