
Commit 5c9cf20

winglian and qgallouedec authored Feb 13, 2025
👨‍👩‍👧 GRPO + PEFT + vLLM (#2818)
* peft + grpo + vllm
* test change
* support model already peft
* Update tests/test_grpo_trainer.py

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 8830786 commit 5c9cf20
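In practice, the combination this commit enables (GRPO training of LoRA adapters with vLLM generation) can be exercised roughly as follows. This is a minimal sketch: the dataset, reward function, and hyperparameter values are illustrative assumptions, not taken from the diff.

```python
# Sketch: GRPO + PEFT (LoRA) + vLLM. Names and values are illustrative.
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")  # any prompt dataset works

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 50 characters.
    return [-abs(50 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="grpo-peft-vllm",
    use_vllm=True,  # completions are generated by vLLM instead of model.generate
    report_to="none",
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
    peft_config=LoraConfig(task_type="CAUSAL_LM"),  # only the LoRA weights are trained
)
trainer.train()
```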

File tree

2 files changed: +37 −20 lines

- tests/test_grpo_trainer.py (+27 −12)
- trl/trainer/grpo_trainer.py (+10 −8)

tests/test_grpo_trainer.py (+27 −12)
@@ -125,7 +125,7 @@ def test_training_peft(self):
 
             self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
-            # Check the peft params have changed and the base model params have not changed
+            # Check that the peft params have changed and the base model params have not changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 if n in base_param_names:  # We expect the base model params to be the same
@@ -168,7 +168,7 @@ def test_training_different_reward_model(self):
 
             self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
-            # Check the params have changed
+            # Check that the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@@ -203,7 +203,7 @@ def reward_func(completions, **kwargs):
 
             self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
-            # Check the params have changed
+            # Check that the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@@ -239,7 +239,7 @@ def reward_func(completions, **kwargs):
 
             self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
-            # Check the params have changed
+            # Check that the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@@ -278,7 +278,7 @@ def reward_func2(completions, **kwargs):
 
             self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
-            # Check the params have changed
+            # Check that the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@@ -356,7 +356,7 @@ def reward_func(completions, **kwargs):
 
             self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
-            # Check the params have changed
+            # Check that the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@@ -395,7 +395,7 @@ def reward_func(completions, some_values, **kwargs):
 
             self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
-            # Check the params have changed
+            # Check that the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@@ -416,9 +416,10 @@ def test_training_vllm(self):
                 report_to="none",
                 use_vllm=True,
                 vllm_device="cuda:0",  # will raise a warning, but allows this test to work with only one GPU
+                vllm_gpu_memory_utilization=0.5,  # reduce, since we use the same device for training and vLLM
             )
             trainer = GRPOTrainer(
-                model="trl-internal-testing/small-Qwen2ForCausalLM-2.5",
+                model="Qwen/Qwen2.5-0.5B-Instruct",  # the tiny model is too small for vLLM
                 reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                 args=training_args,
                 train_dataset=dataset,
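The two `vllm_*` settings added above are what make the single-GPU test possible: vLLM is placed on the same device as training, and its memory pool is capped so the training step still fits. A condensed sketch of just that configuration (`output_dir` is a placeholder):

```python
from trl import GRPOConfig

args = GRPOConfig(
    output_dir="out",                 # placeholder
    use_vllm=True,
    vllm_device="cuda:0",             # share the training GPU; TRL warns about this
    vllm_gpu_memory_utilization=0.5,  # vLLM may claim at most half of the VRAM
)
```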
@@ -504,6 +505,8 @@ def test_training_with_sync_ref_model(self):
     @require_peft
     def test_training_vllm_and_peft(self):
         """Test that training works with vLLM for generation."""
+        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # tiny model is too small for vLLM
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
 
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -513,14 +516,22 @@ def test_training_vllm_and_peft(self):
                 per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
-                use_vllm=True,
                 report_to="none",
+                use_vllm=True,
+                vllm_device="cuda:0",  # will raise a warning, but allows this test to work with only one GPU
+                vllm_gpu_memory_utilization=0.5,  # reduce, since we use the same device for training and vLLM
+            )
+            lora_config = LoraConfig(
+                target_modules="all-linear",
+                # test with non-default modules, as it adds extra keys in the state_dict that we need to handle
+                modules_to_save=["embed_tokens", "lm_head"],
             )
             trainer = GRPOTrainer(
-                model="trl-internal-testing/small-Qwen2ForCausalLM-2.5",
+                model=model,
                 reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                 args=training_args,
                 train_dataset=dataset,
+                peft_config=lora_config,
             )
 
             previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
@@ -529,7 +540,11 @@ def test_training_vllm_and_peft(self):
 
             self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
-            # Check that the params have changed
+            # Check that the peft params have changed and the base model params have not changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
-                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+                if n in base_param_names:  # We expect the base model params to be the same
+                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
+                elif "base_layer" not in n and "original_module" not in n:
+                    # We expect the peft params to be different (except for the base layer)
+                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
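The `base_param_names` bookkeeping above works because PEFT renames parameters when it wraps a model. A sketch of what the test relies on (the model name is reused from the test; the comments describe LoRA's naming scheme, not output copied from a run):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Names of the original parameters, as they will appear inside the PEFT wrapper
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]

peft_model = get_peft_model(
    model,
    LoraConfig(target_modules="all-linear", modules_to_save=["embed_tokens", "lm_head"]),
)
for n, _ in peft_model.named_parameters():
    # Untouched params (e.g. layer norms) keep their name under "base_model.model."
    # and match base_param_names; wrapped linears gain a ".base_layer" segment, and
    # trainable params contain "lora_A"/"lora_B" or "modules_to_save" — which is why
    # the test branches on base_param_names, "base_layer", and "original_module".
    print(n, n in base_param_names)
```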

trl/trainer/grpo_trainer.py (+10 −8)
@@ -22,7 +22,7 @@
 import torch
 import torch.utils.data
 import transformers
-from accelerate.utils import broadcast_object_list, gather, gather_object, set_seed
+from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
 from accelerate.utils.other import is_compiled_module
 from datasets import Dataset, IterableDataset
 from packaging import version
@@ -51,7 +51,7 @@
 
 
 if is_peft_available():
-    from peft import PeftConfig, PeftModel, get_peft_model
+    from peft import PeftConfig, get_peft_model
 
 if is_vllm_available():
     from vllm import LLM, SamplingParams
@@ -249,7 +249,7 @@ def __init__(
         # Reference model
         if is_deepspeed_zero3_enabled():
             self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
-        elif peft_config is None:
+        elif not is_peft_model(model):
             # If PEFT configuration is not provided, create a reference model based on the initial model.
             self.ref_model = create_reference_model(model)
         else:
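The switch from `peft_config is None` to `not is_peft_model(model)` matters because the model passed in may already be a `PeftModel` even when no `peft_config` is given. Either way, a PEFT policy needs no separate reference model: disabling the adapters recovers the frozen base weights. A sketch of that idea (the helper name is ours, not TRL's internal API):

```python
import torch
from peft import PeftModel

def reference_logits(policy: PeftModel, input_ids: torch.Tensor) -> torch.Tensor:
    # With adapters disabled, the forward pass uses only the frozen base
    # weights — i.e. the reference model — without a second copy in memory.
    with torch.no_grad(), policy.disable_adapter():
        return policy(input_ids).logits
```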
@@ -491,16 +491,18 @@ def _move_model_to_vllm(self):
             self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
         ) as unwrapped_model:
             if is_compiled_module(unwrapped_model):
-                state_dict = unwrapped_model._orig_mod.state_dict()
-            elif is_peft_available() and isinstance(unwrapped_model, PeftModel):
+                unwrapped_model = unwrapped_model._orig_mod
+            if is_peft_model(unwrapped_model):
                 unwrapped_model.merge_adapter()
                 state_dict = unwrapped_model.state_dict()
                 unwrapped_model.unmerge_adapter()
+                # Remove the base_model and base_layer prefixes
                 state_dict = {
-                    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
-                    for k, v in state_dict.items()
-                    if self.model.prefix not in k
+                    k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
                 }
+                # Remove values with the adapter prefix (example: "lora_")
+                state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
+                # When using modules_to_save, remove its prefix and discard the original module
                 state_dict = {
                     k.replace("modules_to_save.default.", ""): v
                     for k, v in state_dict.items()
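To see what the three key-cleaning passes in `_move_model_to_vllm` do, here is a standalone sketch applied to representative PEFT state_dict keys (the keys are illustrative; the final filter on "original_module" follows the comment in the diff, whose comprehension is truncated above):

```python
def clean_peft_keys(state_dict: dict, adapter_prefix: str = "lora_") -> dict:
    # 1. Remove the base_model and base_layer prefixes.
    state_dict = {
        k.removeprefix("base_model.model.").replace(".base_layer", ""): v
        for k, v in state_dict.items()
    }
    # 2. Drop the adapter weights themselves (already merged into the base weights).
    state_dict = {k: v for k, v in state_dict.items() if adapter_prefix not in k}
    # 3. Keep the trained copy of modules_to_save, discard the original module.
    return {
        k.replace("modules_to_save.default.", ""): v
        for k, v in state_dict.items()
        if "original_module" not in k
    }

keys = [
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
    "base_model.model.lm_head.modules_to_save.default.weight",
    "base_model.model.lm_head.original_module.weight",
]
print(list(clean_peft_keys({k: None for k in keys})))
# ['model.layers.0.self_attn.q_proj.weight', 'lm_head.weight']
```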
