Support stop strings list for GRPO #2820

Open
haoxiongliu opened this issue Feb 10, 2025 · 0 comments
Labels
✨ enhancement New feature or request 🏋 GRPO Related to GRPO


haoxiongliu commented Feb 10, 2025

Feature request

Add a `--stop_strings` argument to `GRPOConfig`.

When a sampled trajectory produces any string in the list, replace the next token with EOS and continue.
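A minimal sketch of the requested behavior, assuming the completion is available as a list of decoded token strings. The helper name `apply_stop_strings` and the `<eos>` placeholder are hypothetical, not part of TRL. It post-processes an already sampled completion instead of changing the generation loop:

```python
from typing import List, Sequence


def apply_stop_strings(
    tokens: Sequence[str],
    stop_strings: Sequence[str],
    eos_token: str = "<eos>",
) -> List[str]:
    """Hypothetical helper: truncate a sampled completion at the first stop string.

    Keeps every token up to and including the one that completes a stop
    string, then substitutes EOS for the token that would come next.
    """
    text = ""
    for i, token in enumerate(tokens):
        text += token
        # A stop string can span several tokens, so search the decoded text
        # accumulated so far rather than individual tokens.
        if any(stop in text for stop in stop_strings):
            return list(tokens[: i + 1]) + [eos_token]
    # No stop string was produced: return the completion unchanged.
    return list(tokens)


sampled = ["<think>", " 2+2", "=4", " </think>", " <answer>", "4", "</", "answer", ">", " extra", " text"]
print(apply_stop_strings(sampled, ["</answer>"]))
```

Checking the accumulated text rather than single tokens matters because a tokenizer may split a stop string such as `</answer>` across several tokens, as in the usage example.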

Motivation

The current `GRPOTrainer` is designed mainly for SFT models that already emit an EOS token at the right place. To perform R1-Zero-like training on base models, which often do not, a stop-string list is needed.

For more details, see this open-r1 issue.

Your contribution

I tried directly adding `tokenizer=self.processing_class` to `generation_config`, but it seems to cause `all_reduce` issues:

```
[rank6]:   File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/trainer.py", line 2171, in train
[rank6]:     return inner_training_loop(
[rank6]:            ^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/trainer.py", line 2531, in _inner_training_loop
[rank6]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank6]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/trainer.py", line 3675, in training_step
[rank6]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/home/user/lhx/llama_marcel/utils/my_grpo.py", line 94, in compute_loss
[rank6]:     prompt_completion_ids = unwrapped_model.generate(
[rank6]:                             ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank6]:     return func(*args, **kwargs)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2255, in generate
[rank6]:     result = self._sample(
[rank6]:              ^^^^^^^^^^^^^
[rank6]:   File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/generation/utils.py", line 3243, in _sample
[rank6]:     while self._has_unfinished_sequences(
[rank6]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2449, in _has_unfinished_sequences
[rank6]:     dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
[rank6]:   File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank6]:     return func(*args, **kwargs)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 2501, in all_reduce
[rank6]:     work = group.allreduce([tensor], opts)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: RuntimeError: No backend type associated with device type cpu
```

For `use_vllm` mode, I suspect it is as simple as adding `stop=stop_strings` to `sampling_params`, but I have not verified this yet.
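A sketch of how a single `stop_strings` list could be forwarded to both generation paths. The `GRPOStopConfig` dataclass and the `generation_kwargs` helper are hypothetical; the keyword names they produce are real: vLLM's `SamplingParams` accepts `stop` (and `include_stop_str_in_output`), while Hugging Face `generate()` accepts `stop_strings` but also requires the tokenizer to be passed to it. The helper only builds keyword dictionaries, so it runs without either library installed:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class GRPOStopConfig:
    """Hypothetical subset of a GRPO config with the proposed stop_strings field."""
    max_completion_length: int = 256
    stop_strings: Optional[List[str]] = field(default=None)


def generation_kwargs(config: GRPOStopConfig, use_vllm: bool, tokenizer: Any = None) -> Dict[str, Any]:
    """Build generation keyword arguments for either backend (unverified sketch)."""
    kwargs: Dict[str, Any]
    if use_vllm:
        # Intended for vllm.SamplingParams(**kwargs); `stop` ends a sequence
        # as soon as any of the strings is generated.
        kwargs = {"max_tokens": config.max_completion_length}
        if config.stop_strings:
            kwargs["stop"] = list(config.stop_strings)
            # vLLM strips the stop string by default; keep it so reward
            # functions that check for e.g. "</answer>" can still see it.
            kwargs["include_stop_str_in_output"] = True
        return kwargs
    # Intended for model.generate(**inputs, **kwargs) in transformers.
    kwargs = {"max_new_tokens": config.max_completion_length}
    if config.stop_strings:
        if tokenizer is None:
            raise ValueError("transformers needs the tokenizer to match stop strings")
        kwargs["stop_strings"] = list(config.stop_strings)
        kwargs["tokenizer"] = tokenizer
    return kwargs


cfg = GRPOStopConfig(max_completion_length=64, stop_strings=["</answer>"])
print(generation_kwargs(cfg, use_vllm=True))
```

Passing `tokenizer=` to `generate()` is the same change reported above as triggering the `all_reduce` error in multi-GPU runs, so the transformers branch only shows where the list would go; it is not a verified fix.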

I will try to submit a PR once I find a way to solve the issue and have verified my code. Any support is appreciated.

github-actions bot added the ✨ enhancement (New feature or request) and 🏋 GRPO (Related to GRPO) labels Feb 10, 2025