Support context parallel training with ring-flash-attention #33467
Comments
Hey hey! I think, similar to the recent integration of Liger kernels, this would make sense for sure to have as a soft dependency, and monkey patch the …
Hi, that's very nice :) Has this feature been integrated into the Hugging Face library now?
Hmm... After a brief look through, it seems hard to make a clean PR to the transformers repo for ring flash attn... There are a few main obstacles.
I'm afraid we have to postpone this feature until some prerequisites are ready.
Just found out that the DeepSpeed team is adding Ulysses. We can wait until they land the feature, as ring attention and DeepSpeed Ulysses will share many prerequisites~
DeepSpeed added Ulysses. Any plans to integrate this (highly needed for reasoning models)? @ArthurZucker @zhuzilin
@casper-hansen That's great! Could you point to the PR that integrates Ulysses into Hugging Face transformers/accelerate?
@zhuzilin It's not integrated into Hugging Face; it's only in DeepSpeed.
PR #35301 was unfortunately closed! Anyone can work on this!
Feature request
Hi, I'm the author of zhuzilin/ring-flash-attention.
I wonder if you are interested in integrating context parallelism via zhuzilin/ring-flash-attention, so that users can train LLMs on long data more efficiently.
Motivation
With the release of OpenAI o1, it will probably become common to train models on very long CoT data. It would be nice if most models in the transformers library could be trained efficiently on long contexts with some form of context parallelism, i.e. with the context length scaling linearly with the number of GPUs.
The three existing context parallel methods are DeepSpeed Ulysses, ring attention, and the one proposed in the Llama 3 tech report. DeepSpeed Ulysses is limited by the number of KV heads (the maximum context length is `num_head_kv * seq_length_per_gpu`), which makes it a little unfriendly to GQA models; for example, a GQA model with 8 KV heads cannot scale its context beyond 8x the per-GPU sequence length. So it would be great if the transformers library could support one or both of the other two context parallel methods.

Both ring attention and the Llama 3 strategy are supported on top of flash attention in zhuzilin/ring-flash-attention, whose correctness has been verified by jzhang38/EasyContext. The library has essentially the same API as flash attention and hides the required communication from its users, which makes it an easy substitution at any existing flash attention API call site, as sketched below.
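For illustration, here is a minimal usage sketch of that substitution. It assumes `ring_flash_attn_func` mirrors the `flash_attn_func` signature with an extra `group` argument, and that each rank already holds its shard of the sequence; treat the exact names and signatures as illustrative rather than authoritative.

```python
# Minimal sketch: every rank holds a contiguous shard of the full sequence and
# calls the ring variant in place of flash_attn_func.
# Launch with e.g. `torchrun --nproc_per_node=8 demo.py`.
import torch
import torch.distributed as dist
from ring_flash_attn import ring_flash_attn_func  # assumed: flash_attn_func signature + `group`

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())  # single-node assumed for simplicity
world_size = dist.get_world_size()

batch, full_seqlen, num_heads, head_dim = 1, 32768, 8, 128
local_seqlen = full_seqlen // world_size  # context length scales with the number of GPUs

# Local shard of q/k/v: (batch, local_seqlen, num_heads, head_dim), same layout as flash-attn.
q = torch.randn(batch, local_seqlen, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Without context parallelism this call site would be:
#   out = flash_attn.flash_attn_func(q, k, v, causal=True)
# With ring attention the call site stays the same; kv blocks are exchanged
# between ranks inside the wrapper.
out = ring_flash_attn_func(q, k, v, causal=True, group=dist.group.WORLD)
print(out.shape)  # (batch, local_seqlen, num_heads, head_dim)
```

(Note that the zigzag and llama3-style variants expect the sequence to be sharded across ranks in a particular order to balance the causal-attention workload, so the data pipeline has to shard inputs accordingly.)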
Therefore, I believe it would be straightforward to support context parallelism with zhuzilin/ring-flash-attention. For example, we could add a separate branch in `modeling_flash_attention_utils._flash_attention_forward`, roughly along the lines of the sketch below.
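This is a rough, hypothetical sketch of what such a branch could look like, not the actual transformers code: the simplified signature, the `_CONTEXT_PARALLEL_GROUP` hook, and the omission of attention-mask / variable-length handling are all assumptions for illustration.

```python
# Hypothetical branch inside transformers' flash attention helper. The real
# _flash_attention_forward signature is richer (attention masks, varlen paths,
# sliding window, ...), all of which is omitted here.
from flash_attn import flash_attn_func
from ring_flash_attn import zigzag_ring_flash_attn_func  # a llama3-style variant also exists

# Assumed hook, not an existing transformers API: training code would set this
# to a torch.distributed process group during distributed setup.
_CONTEXT_PARALLEL_GROUP = None


def _flash_attention_forward(query_states, key_states, value_states,
                             query_length, is_causal=True, dropout=0.0,
                             softmax_scale=None):
    if _CONTEXT_PARALLEL_GROUP is None:
        # Existing path (simplified): plain flash attention over the local sequence.
        return flash_attn_func(query_states, key_states, value_states,
                               dropout_p=dropout, softmax_scale=softmax_scale,
                               causal=is_causal)
    # New branch: every rank only holds its shard of the sequence; the ring
    # wrapper exchanges kv blocks across the process group internally.
    return zigzag_ring_flash_attn_func(query_states, key_states, value_states,
                                       dropout_p=dropout, softmax_scale=softmax_scale,
                                       causal=is_causal,
                                       group=_CONTEXT_PARALLEL_GROUP)
```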
Your contribution
I'd love to help if you are interested :)