
Commit 8122166

⚡ Fix GRPO PEFT (#2725)
1 parent 7347c29

File tree

3 files changed: +73, -9 lines changed

- tests/test_grpo_trainer.py
- trl/models/utils.py
- trl/trainer/grpo_trainer.py

Diff for: tests/test_grpo_trainer.py (+35)

@@ -498,3 +498,38 @@ def test_training_with_sync_ref_model(self):
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
+    @require_torch_accelerator
+    @require_peft
+    def test_training_vllm_and_peft(self):
+        """Test that training works with vLLM for generation."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=32,  # reduce the completion length to reduce memory usage
+                use_vllm=True,
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/small-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
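For reference, a minimal sketch of the end-to-end setup this commit unblocks, combining GRPO, LoRA, and vLLM. It is not code from the commit: it assumes GRPOTrainer's `peft_config` argument and peft's `LoraConfig`, and the LoRA settings are illustrative defaults.

```python
# Sketch (not from the commit): GRPO + LoRA adapters + vLLM generation.
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(output_dir="grpo-peft-vllm", use_vllm=True, report_to="none")
trainer = GRPOTrainer(
    model="trl-internal-testing/small-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
    peft_config=LoraConfig(),  # the PEFT path this commit fixes
)
trainer.train()
```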

Diff for: trl/models/utils.py (+23, -8)
@@ -20,6 +20,7 @@
 
 from accelerate.utils import is_deepspeed_available
 from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers.utils.deprecation import deprecate_kwarg
 
 from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
 
@@ -37,8 +38,6 @@
     from deepspeed.runtime.engine import DeepSpeedEngine
     from torch.nn.parallel.distributed import DistributedDataParallel
 
-    from .modeling_base import PreTrainedModelWrapper
-
 
 # TODO: Add Abstract Base Class if more formats are added
 @dataclass
@@ -176,18 +175,34 @@ def add_hooks(model: "DeepSpeedEngine") -> None:
 
 
 @contextmanager
+@deprecate_kwarg("is_peft_model", "0.16.0", warn_if_greater_or_equal_version=True)
 def unwrap_model_for_generation(
     model: Union["DistributedDataParallel", "DeepSpeedEngine"],
     accelerator: "Accelerator",
-    is_peft_model: bool = False,
     gather_deepspeed3_params: bool = True,
-) -> Union["PreTrainedModelWrapper", "DeepSpeedEngine"]:
-    """Context manager to unwrap a model for generation.
-    For ZeRO-3 models, we gather the weights once to speed up generation.
+):
+    """
+    Context manager to unwrap distributed or accelerated models for generation tasks.
+
+    Args:
+        model (`Union[DistributedDataParallel, DeepSpeedEngine]`):
+            Model to be unwrapped.
+        accelerator (`~accelerate.Accelerator`):
+            Accelerator instance managing the model.
+        gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
+            Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
+            can be more memory-efficient but may lead to slower generation times.
+
+    Yields:
+        Unwrapped model.
+
+    Example:
+    ```python
+    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
+        generated_outputs = unwrapped_model.generate(input_ids)
+    ```
     """
     unwrapped_model = accelerator.unwrap_model(model)
-    if is_peft_model:
-        unwrapped_model.pretrained_model.disable_adapter()
     if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
         if not gather_deepspeed3_params:
            yield accelerator.unwrap_model(model)
Diff for: trl/trainer/grpo_trainer.py (+15, -1)
@@ -51,7 +51,7 @@
 
 
 if is_peft_available():
-    from peft import PeftConfig, get_peft_model
+    from peft import PeftConfig, PeftModel, get_peft_model
 
 if is_vllm_available():
     from vllm import LLM, SamplingParams
@@ -492,6 +492,20 @@ def _move_model_to_vllm(self):
         ) as unwrapped_model:
             if is_compiled_module(unwrapped_model):
                 state_dict = unwrapped_model._orig_mod.state_dict()
+            elif isinstance(unwrapped_model, PeftModel):
+                unwrapped_model.merge_adapter()
+                state_dict = unwrapped_model.state_dict()
+                unwrapped_model.unmerge_adapter()
+                state_dict = {
+                    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
+                    for k, v in state_dict.items()
+                    if self.model.prefix not in k
+                }
+                state_dict = {
+                    k.replace("modules_to_save.default.", ""): v
+                    for k, v in state_dict.items()
+                    if "original_module" not in k
+                }
             else:
                 state_dict = unwrapped_model.state_dict()
         if self.accelerator.is_main_process:
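The PEFT branch merges the adapter into the base weights (`merge_adapter`), snapshots the merged `state_dict`, then unmerges so training can continue; the two dict comprehensions rewrite PEFT-style parameter names into plain base-model names that vLLM can load. A self-contained sketch of just the key cleanup, run on hypothetical LoRA-style keys; the `"lora_"` prefix stands in for `self.model.prefix` and the tensor values are stubbed with `None`:

```python
# Illustrative PEFT-wrapped keys after merge_adapter() (values stubbed).
prefix = "lora_"  # adapter prefix; stands in for self.model.prefix
state_dict = {
    "base_model.model.model.layers.0.q_proj.base_layer.weight": None,
    "base_model.model.model.layers.0.q_proj.lora_A.default.weight": None,  # adapter tensor
    "base_model.model.lm_head.modules_to_save.default.weight": None,  # trained copy
    "base_model.model.lm_head.original_module.weight": None,  # frozen original
}

# 1) Drop adapter tensors, strip the PEFT wrapper prefix and ".base_layer".
state_dict = {
    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
    for k, v in state_dict.items()
    if prefix not in k
}
# 2) For modules_to_save, keep the trained copy and drop the frozen original.
state_dict = {
    k.replace("modules_to_save.default.", ""): v
    for k, v in state_dict.items()
    if "original_module" not in k
}
print(sorted(state_dict))
# ['lm_head.weight', 'model.layers.0.q_proj.weight']
```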
