Commit dd55d46

peft + grpo + vllm
1 parent 8122166 commit dd55d46

File tree

2 files changed (+18 -11 lines changed)


Diff for: tests/test_grpo_trainer.py (+8)

@@ -516,11 +516,19 @@ def test_training_vllm_and_peft(self):
                 use_vllm=True,
                 report_to="none",
             )
+            lora_config = LoraConfig(
+                r=4,
+                lora_alpha=8,
+                lora_dropout=0.05,
+                target_modules="all-linear",
+                modules_to_save=["embed_tokens", "lm_head"],
+            )
             trainer = GRPOTrainer(
                 model="trl-internal-testing/small-Qwen2ForCausalLM-2.5",
                 reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                 args=training_args,
                 train_dataset=dataset,
+                peft_config=lora_config,
             )
 
             previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
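
For context, the new test exercises all three pieces together: a LoRA adapter (with `modules_to_save` covering the embeddings and LM head) is handed to `GRPOTrainer` via `peft_config`, while `use_vllm=True` routes generation through vLLM. A minimal standalone sketch of the same setup follows; the base model, dataset, and reward function here are placeholders chosen for illustration, not taken from this commit:

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Placeholder prompt dataset; any dataset with a "prompt" column works.
dataset = load_dataset("trl-lib/tldr", split="train")

# Hypothetical reward function for illustration: favors longer completions.
def reward_len(completions, **kwargs):
    return [float(len(c)) for c in completions]

training_args = GRPOConfig(
    output_dir="grpo-peft-vllm",
    use_vllm=True,  # generate rollouts with vLLM instead of model.generate
    report_to="none",
)

lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules="all-linear",
    # Fully train the embeddings and LM head alongside the LoRA adapter.
    modules_to_save=["embed_tokens", "lm_head"],
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder base model
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
    peft_config=lora_config,
)
trainer.train()
```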

Diff for: trl/trainer/grpo_trainer.py (+10 -11)

@@ -22,7 +22,7 @@
 import torch
 import torch.utils.data
 import transformers
-from accelerate.utils import broadcast_object_list, gather, gather_object, set_seed
+from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
 from accelerate.utils.other import is_compiled_module
 from datasets import Dataset, IterableDataset
 from packaging import version

@@ -491,21 +491,20 @@ def _move_model_to_vllm(self):
             self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
         ) as unwrapped_model:
             if is_compiled_module(unwrapped_model):
-                state_dict = unwrapped_model._orig_mod.state_dict()
-            elif isinstance(unwrapped_model, PeftModel):
+                unwrapped_model = unwrapped_model._orig_mod
+            if is_peft_model(unwrapped_model):
                 unwrapped_model.merge_adapter()
                 state_dict = unwrapped_model.state_dict()
-                unwrapped_model.unmerge_adapter()
                 state_dict = {
-                    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
+                    k.removeprefix("base_model.model.")
+                    .removeprefix("base_model.model.")
+                    .replace(".default", "")
+                    .replace(".base_layer", "")
+                    .replace(".modules_to_save", ""): v
                     for k, v in state_dict.items()
-                    if self.model.prefix not in k
-                }
-                state_dict = {
-                    k.replace("modules_to_save.default.", ""): v
-                    for k, v in state_dict.items()
-                    if "original_module" not in k
+                    if unwrapped_model.prefix not in k and "original_module" not in k
                 }
+                unwrapped_model.unmerge_adapter()
             else:
                 state_dict = unwrapped_model.state_dict()
             if self.accelerator.is_main_process:
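
To unpack the reworked branch in `_move_model_to_vllm`: after `merge_adapter()`, the PEFT-wrapped state dict still carries wrapper artifacts (the `base_model.model.` prefix, `.base_layer` on LoRA-targeted linears, `modules_to_save.default` wrappers, plus the raw `lora_*` and `original_module` entries), and the single comprehension now strips and filters all of them so the keys match the plain base-model layout vLLM expects; `unmerge_adapter()` is called only after the dict has been built, so the merged weights are the ones copied out. A toy sketch of the rewrite rules on made-up keys (the `prefix` value is an assumption matching PEFT's LoRA tuner):

```python
# Made-up PEFT-style keys to show each rename/filter rule firing.
raw = {
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": 0,
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": 1,
    "base_model.model.lm_head.modules_to_save.default.weight": 2,
    "base_model.model.lm_head.original_module.weight": 3,
}

prefix = "lora_"  # assumed value of PeftModel.prefix for LoRA adapters

cleaned = {
    k.removeprefix("base_model.model.")
    .removeprefix("base_model.model.")  # applied twice for doubly wrapped models
    .replace(".default", "")
    .replace(".base_layer", "")
    .replace(".modules_to_save", ""): v
    for k, v in raw.items()
    # Drop the raw adapter weights and the frozen originals kept by PEFT.
    if prefix not in k and "original_module" not in k
}

print(list(cleaned))
# ['model.layers.0.self_attn.q_proj.weight', 'lm_head.weight']
```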
