Feature request

Add a --stop_strings argument to GRPOConfig. When a sampled trajectory produces any string in the list, substitute the next token with EOS and continue.

Motivation

The current GRPOTrainer is mainly designed for SFT models, which can emit an EOS token appropriately. To perform R1-Zero-like training on base models, a stop-string list is needed. For more details, see this open-r1 issue.

Your contribution

I have tried directly adding tokenizer=self.processing_class to generation_config, but it seems to cause all_reduce issues:
[rank6]: File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/trainer.py", line 2171, in train
[rank6]: return inner_training_loop(
[rank6]: ^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/trainer.py", line 2531, in _inner_training_loop
[rank6]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/trainer.py", line 3675, in training_step
[rank6]: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/home/user/lhx/llama_marcel/utils/my_grpo.py", line 94, in compute_loss
[rank6]: prompt_completion_ids = unwrapped_model.generate(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank6]: return func(*args, **kwargs)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2255, in generate
[rank6]: result = self._sample(
[rank6]: ^^^^^^^^^^^^^
[rank6]: File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/generation/utils.py", line 3243, in _sample
[rank6]: while self._has_unfinished_sequences(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2449, in _has_unfinished_sequences
[rank6]: dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
[rank6]: File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank6]: return func(*args, **kwargs)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/home/user/miniconda3/envs/myenv/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 2501, in all_reduce
[rank6]: work = group.allreduce([tensor], opts)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: RuntimeError: No backend type associated with device type cpu
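For reference, the intended behavior described above — cutting a completion at the first stop-string match and substituting EOS — can be sketched post-hoc on decoded text. The helper name, example stop strings, and EOS token below are illustrative assumptions, not TRL or transformers API:

```python
def truncate_at_stop_strings(completion: str, stop_strings: list[str],
                             eos_token: str) -> str:
    """Cut `completion` at the earliest stop-string match and append EOS.

    Hypothetical helper illustrating the requested behavior; it works on
    decoded strings rather than token ids, so it approximates the
    "substitute the next token by EOS" semantics.
    """
    cut = len(completion)
    for s in stop_strings:
        idx = completion.find(s)
        if idx != -1:
            cut = min(cut, idx)  # keep the earliest match across all stops
    if cut == len(completion):
        return completion  # no stop string found; leave the completion as-is
    return completion[:cut] + eos_token

# Example: a base model rambles past its answer into a new question.
text = "The answer is 42. Question: what is"
print(truncate_at_stop_strings(text, ["Question:"], "</s>"))
# -> "The answer is 42. </s>"
```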
I guess that for use_vllm mode it is as simple as adding stop=stop_strings to sampling_params, but I have not verified this yet.
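Concretely, that guess amounts to forwarding the stop list when the trainer builds its sampling parameters. The sketch below is an unverified assumption about where the wiring would go, shown as a plain dict so the mapping is visible without importing vllm; vLLM's SamplingParams does accept a `stop` list of strings:

```python
# Unverified sketch: in use_vllm mode, a stop list on the config could be
# forwarded roughly as:
#
#     sampling_params = SamplingParams(..., stop=config.stop_strings)
#
# Plain-dict stand-in for the same mapping (all values hypothetical):
stop_strings = ["Question:", "Human:"]  # hypothetical stop list
sampling_kwargs = {
    "temperature": 0.7,
    "max_tokens": 512,
    "stop": stop_strings,  # vLLM stops generation when any string is produced
}
print(sampling_kwargs["stop"])
```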
I will try to submit a PR once I find a way to solve the issue and have verified my code. I would appreciate any support.