✨ Add vLLM guided decoding support to GRPO Trainer #2811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 7 commits · Feb 18, 2025
36 changes: 36 additions & 0 deletions tests/test_grpo_trainer.py
@@ -548,3 +548,39 @@ def test_training_vllm_and_peft(self):
elif "base_layer" not in n and "original_module" not in n:
# We expect the peft params to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")

@unittest.skipIf(not is_vllm_available(), "vLLM is not available")
@require_torch_accelerator
def test_training_vllm_guided_decoding(self):
"""Test that training works with vLLM for generation with guided decoding."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
use_vllm=True,
vllm_device="cuda:0", # will raise a warning, but allows this test to work with only one GPU
vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
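
For reference, the regex used in this test forces every completion into a fixed `<reasoning>`/`<answer>` skeleton. A minimal sketch checking a made-up completion against that pattern with Python's `re` (the completion string is illustrative, not produced by the test):

```python
import re

# Regex from the test above: completions must wrap their content in
# <reasoning> and <answer> blocks, each opened and closed on its own line.
pattern = r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>"

# Hypothetical completion in the expected format.
completion = "<reasoning>\nTwo plus two equals four.\n</reasoning>\n<answer>\n4\n</answer>"

assert re.fullmatch(pattern, completion) is not None
```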
6 changes: 6 additions & 0 deletions trl/trainer/grpo_config.py
@@ -79,6 +79,8 @@ class GRPOConfig(TrainingArguments):
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
context size, which might be much larger than the KV cache, leading to inefficiencies.
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.

> Parameters that control the training

@@ -201,6 +203,10 @@ class GRPOConfig(TrainingArguments):
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)
vllm_guided_decoding_regex: Optional[str] = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)

# Parameters that control the training
learning_rate: float = field(
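
Taken together with the trainer change below, enabling guided decoding from user code only requires setting this new config field. A minimal sketch, assuming the same model and dataset as the test plus an illustrative `format_reward` function (not part of this PR):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Illustrative reward: favor completions that contain an <answer> block.
def format_reward(completions, **kwargs):
    return [1.0 if "<answer>" in completion else 0.0 for completion in completions]

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="grpo-guided-decoding",
    use_vllm=True,
    # New option added by this PR: constrain vLLM generations to this regex.
    vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=format_reward,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```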
10 changes: 10 additions & 0 deletions trl/trainer/grpo_trainer.py
@@ -55,6 +55,7 @@

if is_vllm_available():
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

if is_wandb_available():
import wandb
@@ -412,9 +413,18 @@ def data_collator(features): # No data collation is needed in GRPO
enable_prefix_caching=True,
max_model_len=self.args.vllm_max_model_len,
)

# Guided decoding, if enabled
if args.vllm_guided_decoding_regex is not None:
guided_decoding = GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex)
else:
guided_decoding = None

# Sampling parameters
self.sampling_params = SamplingParams(
temperature=args.temperature,
max_tokens=self.max_completion_length,
guided_decoding=guided_decoding,
n=args.num_generations,
)

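
Outside the trainer, the same two vLLM objects can be exercised directly, which is a quick way to sanity-check a regex before a full training run. A minimal sketch, assuming a local GPU and the same small Qwen model (model choice, prompt, and sampling values are arbitrary):

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Same construction as in the trainer: the regex constrains sampling so the
# generated text matches the pattern.
guided_decoding = GuidedDecodingParams(
    backend="outlines",
    regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
)
sampling_params = SamplingParams(
    temperature=0.9,
    max_tokens=64,
    guided_decoding=guided_decoding,
)

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")
outputs = llm.generate(["What is 2 + 2?"], sampling_params)
print(outputs[0].outputs[0].text)  # expected to conform to the regex
```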