Modified GRPOTrainer to accumulate gradient within a single training batch #3288
base: main
Conversation
Thanks @jarrelscy, I understand the motivation. Just for clarification: why not use gradient accumulation instead? Is it because generation would also be done on smaller batches, which would make things slower?
Hi @qgallouedec, as @JamesBowerXanda pointed out here, the quality of the loss depends on the group size. In this paper they point out that you need a large group size to approximate the expected reward normalised by the standard deviation of the reward of an output sampled from the previous policy. In GRPO, each generation is assigned an advantage relative to the other generations in its group, so if the group size is small this can lead to erratic losses. With gradient accumulation (per batch), we would still only be comparing the advantage of each generation against the other generations within that batch.
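For context, a minimal sketch of the group-relative advantage GRPO uses (illustrative only; the tensor names, shapes, and epsilon are assumptions, not the trainer's actual code). Each completion's reward is normalised against the mean and standard deviation of its own group, so a small group gives a noisy estimate of both statistics and hence noisy advantages:

```python
import torch

# Illustrative sketch, not GRPOTrainer internals.
# rewards: one scalar reward per completion, grouped by prompt.
num_prompts, num_generations = 4, 8  # assumed shapes for illustration
rewards = torch.randn(num_prompts, num_generations)

# Each completion is compared only against the other completions for the same prompt,
# so a small num_generations makes the per-group mean/std (and the advantages) erratic.
group_mean = rewards.mean(dim=1, keepdim=True)
group_std = rewards.std(dim=1, keepdim=True)
advantages = (rewards - group_mean) / (group_std + 1e-4)
```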
FYI, now you can pass a group as large as
Closing this as I believe the motivation behind this PR has been addressed by #3283 |
@qgallouedec this PR is not exactly the same as #3283 - it's akin to the comment here on #3283, which states that this functionality is not implemented in #3283.
Ok, sorry for the misjudgement; I'm reopening the PR.
Hello, any update on this? Have you merged it into the master branch?
@jiangix-paper there are some new changes on the trl main branch which I think are not compatible, namely the entropy masking implementation. I've just done a merge; feel free to try cloning and testing it.
What does this PR do?
GRPOTrainer calculates advantages and then calculates the loss per completion. Currently this is all done within a single batch, which can take a lot of memory. Just like with gradient accumulation, we can call .backward() on the loss for each completion separately. This PR does so by introducing a new parameter into GRPOConfig called num_generations_chunks, which num_generations needs to be a multiple of. Doing so causes loss.backward() to be called per num_generations_chunks completions.
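Conceptually, the change amounts to the following pattern (a simplified sketch under assumed names, not the trainer's actual code; `loss_fn`, `completions`, and `num_chunks` are placeholders): advantages are computed once over the full group, then the loss and its backward pass run chunk by chunk, so only one chunk's activations are alive at a time while gradients accumulate in the parameters.

```python
import torch

def chunked_backward(completions, advantages, num_chunks, loss_fn):
    """Sketch: call .backward() per chunk of completions instead of once for the whole group.

    `completions`, `advantages`, and `loss_fn` stand in for the tokenised completion
    tensors and the GRPO per-token loss used by the real trainer.
    """
    comp_chunks = torch.chunk(completions, num_chunks, dim=0)
    adv_chunks = torch.chunk(advantages, num_chunks, dim=0)
    for comp_chunk, adv_chunk in zip(comp_chunks, adv_chunks):
        # Scale each chunk's loss so the accumulated gradient matches a single
        # backward pass over the full batch (standard gradient-accumulation scaling).
        loss = loss_fn(comp_chunk, adv_chunk) / num_chunks
        loss.backward()  # gradients accumulate in .grad; optimizer.step() runs once afterwards
```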
Example usage:
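The original example was not preserved here; a hypothetical sketch of what usage might look like, based only on the description above (the checkpoint, reward function, dataset, and specific values are illustrative assumptions, and only num_generations_chunks comes from this PR):

```python
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# Tiny illustrative prompt dataset.
train_dataset = Dataset.from_dict(
    {"prompt": ["Write a haiku about the sea.", "Explain GRPO in one line."]}
)

config = GRPOConfig(
    output_dir="grpo-output",
    per_device_train_batch_size=16,
    num_generations=16,
    num_generations_chunks=4,  # new parameter from this PR; num_generations must be a multiple of it
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # any causal LM; this checkpoint is illustrative
    reward_funcs=lambda completions, **kwargs: [len(c) / 100 for c in completions],
    args=config,
    train_dataset=train_dataset,
)
trainer.train()
```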
Fixes #3017
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.