
Remote GRPO ref model #2763

Draft · wants to merge 5 commits into main
Conversation

@edbeeching (Collaborator) commented Feb 4, 2025

Adds an option to use a remote reference model hosted on another node; the user provides its URL.
I originally planned to use vllm serve, but that would require decoding the token ids, re-encoding them, and so on. This approach seems simpler, though it may be less robust. Speed is not a concern, since it is only a forward pass rather than autoregressive generation.
Usage:

# launch the app on one node
python trl/models/remote_model_app.py --model_name deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

# then point the training script at it via the ref_model_url option, e.g.
# ref_model_url: http://1.2.3.4:8000
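For illustration, here is a minimal sketch of what such a remote forward-pass endpoint could look like (not the PR's actual `remote_model_app.py`; the `/forward` route and payload fields are assumptions):

```python
# Hypothetical sketch: a FastAPI app that loads a fixed reference model and
# returns per-token log-probs for already-tokenized sequences in one forward pass.
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM

app = FastAPI()
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", torch_dtype=torch.bfloat16
).to("cuda").eval()  # assumes the ref model fits on a single GPU


class ForwardRequest(BaseModel):
    input_ids: list[list[int]]  # padded prompt+completion token ids


@app.post("/forward")
def forward(req: ForwardRequest):
    ids = torch.tensor(req.input_ids, device=model.device)
    with torch.no_grad():
        logits = model(ids).logits[:, :-1]  # logits that predict token t+1 from prefix <= t
    logps = torch.log_softmax(logits.float(), dim=-1)
    # log-prob of each actual next token
    per_token_logps = logps.gather(-1, ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    return {"logps": per_token_logps.cpu().tolist()}
```

Served with something like `uvicorn remote_model_app:app --port 8000`, this exposes the same kind of URL the training script consumes.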

Still to do:

  • Multi-GPU training, as the calls to the remote model may conflict with each other
  • Large batches
  • Add better error handling / restarts
  • The current implementation assumes the model + activations fit on one GPU, which is fine for models up to ~14B parameters, but we will need a different approach for larger models
  • Create a PR on the open-r1 repo with a slurm script and recipe

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@shirinyamani (Contributor)

If the ref_model lives on another machine (node), I don't fully understand how FSDP can take place. Just to improve my knowledge, I'd appreciate an explanation of how the sync can happen without conflict.

@Superskyyy (Contributor)

This would potentially conflict with PR #2684 though; maybe it needs a note in the docs.

@edbeeching (Collaborator, Author)

@shirinyamani
The ref model is fixed, so there is no need to sync weights to it. It could be wrapped with FSDP, although this implementation does not expose that option; we assume it is small enough to fit on one GPU.

Only the model being optimized is sharded in this setting; the ref model runs on another node in order to free memory on the node used for optimization.

@Superskyyy, good point. Yes, I do not think that iterative GRPO is compatible with this option.
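To make the data flow concrete, here is a hedged sketch of the trainer side under the same assumptions as the `/forward` sketch above (the request format and helper names are illustrative, not the PR's exact API); the per-token KL term follows the k3 estimator used by TRL's GRPO trainer:

```python
# Hypothetical client-side helpers: fetch reference log-probs from the remote
# node and plug them into GRPO's per-token KL penalty.
import requests
import torch


def remote_ref_logps(ref_model_url: str, input_ids: torch.Tensor) -> torch.Tensor:
    """Ask the remote (fixed, unsharded) ref model for per-token log-probs."""
    resp = requests.post(
        f"{ref_model_url}/forward",
        json={"input_ids": input_ids.tolist()},
        timeout=60,
    )
    resp.raise_for_status()
    return torch.tensor(resp.json()["logps"], device=input_ids.device)


def per_token_kl(per_token_logps: torch.Tensor, ref_per_token_logps: torch.Tensor) -> torch.Tensor:
    """k3 estimator: exp(ref - pol) - (ref - pol) - 1, computed per token."""
    delta = ref_per_token_logps - per_token_logps
    return torch.exp(delta) - delta - 1
```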

@@ -78,6 +78,9 @@ class GRPOConfig(TrainingArguments):
Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient.

> Parameters that control remote models
ref_model_url: str
A Member reviewed the `ref_model_url` docstring and suggested a change:

Suggested change
- ref_model_url: str
+ ref_model_url: str
+     .... Using a remote ref model isn't compatible with ref model syncing.
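As a usage illustration of that constraint (assuming this PR's branch; `sync_ref_model` is the name of the syncing option from PR #2684 and is an assumption here):

```python
# Hypothetical configuration sketch, not a confirmed API of merged TRL.
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="grpo-remote-ref",
    beta=0.04,                            # KL coefficient
    ref_model_url="http://1.2.3.4:8000",  # remote reference model endpoint (this PR)
    sync_ref_model=False,                 # ref-model syncing is incompatible with a remote ref model
)
```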

@Superskyyy (Contributor)

> @shirinyamani
> The ref model is fixed, so there is no need to sync weights to it. It could be wrapped with FSDP, although this implementation does not expose that option; we assume it is small enough to fit on one GPU.
>
> Only the model being optimized is sharded in this setting; the ref model runs on another node in order to free memory on the node used for optimization.
>
> @Superskyyy, good point. Yes, I do not think that iterative GRPO is compatible with this option.

When a more distributed backend is built into the lib, this can be solved naturally.

@edbeeching (Collaborator, Author) commented Feb 5, 2025

@Superskyyy
Just pasting a comment from our internal Slack below; I should be able to include iterative GRPO in this PR using a remote SGLang endpoint:

> I am going to refactor it to work with an SGLang endpoint. The reasoning is that SGLang endpoints have two nice features: they work directly with the token_ids, and they can reload new model weights from disk.
> The weight reload is useful in two settings: iterative GRPO, which is already available in TRL by this PR, and also async GRPO, which may be required in the future to scale up the training.
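A rough sketch of how the trainer could trigger that reload between iterations follows; the endpoint path and payload are assumptions about an SGLang-style server, not a confirmed API:

```python
# Hypothetical sketch: ask an SGLang-style server to reload refreshed reference
# weights from disk (e.g. after an iterative-GRPO update saves a new snapshot).
import requests


def reload_ref_weights(server_url: str, checkpoint_dir: str) -> None:
    resp = requests.post(
        f"{server_url}/update_weights_from_disk",  # assumed route name
        json={"model_path": checkpoint_dir},       # assumed payload field
        timeout=300,
    )
    resp.raise_for_status()


# e.g. reload_ref_weights("http://1.2.3.4:30000", "/checkpoints/iter_1")
```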

@Superskyyy (Contributor) commented Feb 6, 2025

> @Superskyyy Just pasting a comment from our internal Slack below; I should be able to include iterative GRPO in this PR using a remote SGLang endpoint:
>
> I am going to refactor it to work with an SGLang endpoint. The reasoning is that SGLang endpoints have two nice features: they work directly with the token_ids, and they can reload new model weights from disk.
> The weight reload is useful in two settings: iterative GRPO, which is already available in TRL by this PR, and also async GRPO, which may be required in the future to scale up the training.

@edbeeching Thanks! I'm planning some further decoupling and efficiency gains; once this is merged, I will try to add something on top of it this weekend.
