Support context parallel training with ring-flash-attention #33467

Open · zhuzilin opened this issue Sep 13, 2024 · 9 comments

@zhuzilin

Feature request

Hi, I'm the author of zhuzilin/ring-flash-attention.

I wonder if you are interested in integrating context parallelism via zhuzilin/ring-flash-attention, so that users can train LLMs on long data more efficiently.

Motivation

With the release of OpenAI o1, it will probably become common for people to train models on really long CoT data. It would be nice if most models in the transformers library could support long-context training efficiently through some form of context parallelism, i.e. the supported context length scales linearly with the number of GPUs.

The three existing context parallel methods are DeepSpeed Ulysses, ring attention, and the one proposed in the Llama 3 tech report. DeepSpeed Ulysses is limited by the number of KV heads (the maximum context length is num_head_kv * seq_length_per_gpu), which makes it a little unfriendly to GQA models; for example, a model with 8 KV heads can shard a sequence across at most 8 GPUs with Ulysses. So it would be great if the transformers library could support one or both of the other two context parallel methods.

Both ring attention and the Llama 3 strategy are supported with flash attention in zhuzilin/ring-flash-attention, whose correctness has been verified by jzhang38/EasyContext. The library has essentially the same API as flash attention and hides the required communication from its user, which makes it an easy drop-in substitution at any existing flash attention call site.

Therefore, I believe it would be easy to support context parallelism with zhuzilin/ring-flash-attention. For example, we could add a separate branch in modeling_flash_attention_utils._flash_attention_forward, as sketched below.
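To make the idea concrete, here is a minimal sketch of such a branch, assuming ring-flash-attention is installed as a soft dependency. The simplified signature, the use_ring_flash_attn flag, and the cp_group argument are illustrative, not existing transformers parameters:

```python
# Sketch only: a hypothetical dispatch inside a simplified
# _flash_attention_forward; use_ring_flash_attn and cp_group are
# illustrative, not existing transformers arguments.
try:
    from ring_flash_attn import ring_flash_attn_func  # soft dependency
except ImportError:
    ring_flash_attn_func = None


def _flash_attention_forward(query_states, key_states, value_states,
                             attention_mask, query_length, is_causal=True,
                             dropout=0.0, softmax_scale=None,
                             use_ring_flash_attn=False, cp_group=None,
                             **kwargs):
    if use_ring_flash_attn and ring_flash_attn_func is not None:
        # ring_flash_attn_func mirrors the flash_attn_func call signature and
        # performs the ring exchange of K/V blocks internally, so nothing
        # else in the modeling code needs to change beyond this dispatch.
        return ring_flash_attn_func(
            query_states, key_states, value_states,
            dropout_p=dropout, softmax_scale=softmax_scale,
            causal=is_causal, group=cp_group,
        )
    ...  # fall through to the existing flash_attn_func / varlen path
```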

Your contribution

I'd love to help if you're interested :)

zhuzilin added the Feature request label on Sep 13, 2024
@LysandreJik
Member

cc @ArthurZucker

@ArthurZucker
Collaborator

Hey hey! Similarly to the recent integration of Liger kernels, I think this would make sense as a soft dependency that monkey-patches the flash attention forward 🤗 Happy to review a PR like this!
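For illustration, a rough sketch of the kind of monkey patch being described, assuming the ring_flash_attn package exposes a flash-attention-like ring_flash_attn_func; the apply_ring_attention_patch helper is purely hypothetical:

```python
# Sketch only: monkey-patching the flash attention forward from outside the
# library, in the spirit of the Liger kernel integration.
# apply_ring_attention_patch is an illustrative name, not an existing API.
import transformers.modeling_flash_attention_utils as fa_utils
from ring_flash_attn import ring_flash_attn_func


def apply_ring_attention_patch(cp_group=None):
    original_forward = fa_utils._flash_attention_forward

    def patched_forward(query_states, key_states, value_states,
                        attention_mask, query_length, **kwargs):
        # Only cover the dense causal path here; anything with padding or
        # varlen packing falls back to the original implementation.
        if attention_mask is None:
            return ring_flash_attn_func(
                query_states, key_states, value_states,
                causal=kwargs.get("is_causal", True), group=cp_group,
            )
        return original_forward(query_states, key_states, value_states,
                                attention_mask, query_length, **kwargs)

    # Note: modeling files that imported _flash_attention_forward directly
    # before the patch would not pick this up; a real integration would need
    # to hook in earlier, e.g. when the attention implementation is selected.
    fa_utils._flash_attention_forward = patched_forward
```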

@ZetangForward

Hi, that's very nice :) Has this feature been integrated into the Hugging Face library now?

@zhuzilin
Author

zhuzilin commented Oct 8, 2024

Hmm... after a brief look through, it seems hard to make a clean PR to the transformers repo for ring flash attention. There are three main obstacles:

  1. For really long data, different ranks may need to share the same data, e.g. 4 GPUs sharing one batch that contains only 1 sample. In that case we need to customize the num_replicas and rank passed to the samplers, which are currently managed by accelerate, and at the moment we cannot pass a config to accelerate to change that behavior.
  2. Ring attention needs a new process group for the GPUs that share the same batch, so that they can communicate with each other during the attention calculation (a sketch of points 1 and 2 is included below). There is no clear place to put that. (We could set a global variable for the process group, but I'm not sure whether that would break some config branch of the Trainer...)
  3. We cannot change only the accelerate repo, because we also need a new way to manage the data in the Trainer, something like DataCollatorWithFlattening.

I'm afraid we have to postpone this feature until some prerequisites are ready.
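To make obstacles 1 and 2 concrete, here is a minimal sketch of the sampler arguments and the extra process group that would be needed, assuming a context-parallel size cp_size; the helper names are illustrative and nothing like this exists in transformers or accelerate today:

```python
# Hypothetical helpers illustrating obstacles 1 and 2; the names
# build_context_parallel_group and build_sampler are illustrative only.
import torch.distributed as dist
from torch.utils.data import DistributedSampler


def build_context_parallel_group(cp_size):
    """Group consecutive ranks so that each group shares one long batch."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % cp_size == 0

    cp_group = None
    for start in range(0, world_size, cp_size):
        ranks = list(range(start, start + cp_size))
        # new_group must be called by every rank with the same arguments,
        # even for groups the current rank does not belong to.
        group = dist.new_group(ranks)
        if rank in ranks:
            cp_group = group
    return cp_group


def build_sampler(dataset, cp_size):
    """Obstacle 1: ranks in the same context-parallel group must receive the
    same samples, so num_replicas/rank are derived from the data-parallel
    layout instead of the global world size."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    return DistributedSampler(
        dataset,
        num_replicas=world_size // cp_size,  # number of data-parallel groups
        rank=rank // cp_size,                # all cp ranks get the same shard
    )
```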

@zhuzilin
Author

zhuzilin commented Oct 9, 2024

Just found out that the DeepSpeed team is adding Ulysses. We can wait until they land that feature, since ring attention and DeepSpeed Ulysses will share many prerequisites~

@casper-hansen

DeepSpeed added Ulysses. Any plans to integrate this (highly needed for reasoning models)? @ArthurZucker @zhuzilin

https://www.deepspeed.ai/tutorials/ds-sequence/

@zhuzilin
Author

zhuzilin commented Feb 5, 2025

DeepSpeed added Ulysses.

@casper-hansen That's great! Could you point to the PR that integrates Ulysses into huggingface transformers/accelerate?

@casper-hansen

@zhuzilin It's not integrated into Hugging Face; it's only in DeepSpeed.

@ArthurZucker
Collaborator

PR #35301 was unfortunately closed! Anyone can work on this!
