Commit ead17e3

whoops. forgot the lora merge part
1 parent 2611708 commit ead17e3

File tree: 2 files changed (+67 −3 lines)

trl/trainer/grpo_trainer.py (+5 −3)

```diff
@@ -22,7 +22,7 @@
 import torch
 import torch.utils.data
 import transformers
-from accelerate.utils import broadcast_object_list, gather, gather_object
+from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model
 from accelerate.utils.other import is_compiled_module
 from datasets import Dataset, IterableDataset
 from packaging import version
@@ -47,7 +47,7 @@
 from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
-from .utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax
+from .utils import generate_model_card, get_comet_experiment_url, get_lora_merged_state_dict, pad, selective_log_softmax


 if is_peft_available():
@@ -449,7 +449,9 @@ def _move_model_to_vllm(self):
             self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
         ) as unwrapped_model:
             if is_compiled_module(unwrapped_model):
-                state_dict = unwrapped_model._orig_mod.state_dict()
+                unwrapped_model = unwrapped_model._orig_mod
+            if is_peft_model(unwrapped_model):
+                state_dict = get_lora_merged_state_dict(unwrapped_model)
             else:
                 state_dict = unwrapped_model.state_dict()
             if self.accelerator.is_main_process:
```
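
The practical effect of this branch: a PEFT-wrapped model's raw `state_dict()` is not directly loadable by a plain (non-LoRA) vLLM worker, because its keys carry the `base_model.model.` prefix and the adapters live in separate `lora_A`/`lora_B` tensors. A minimal sketch of the problem, assuming a small placeholder checkpoint and LoRA config (neither comes from this commit):

```python
# Illustrative only -- the checkpoint name and LoRA settings are assumptions,
# not part of this commit.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
peft_model = get_peft_model(base, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))

adapter_keys = [k for k in peft_model.state_dict() if "lora_" in k]
print(adapter_keys[0])
# e.g. base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight
# get_lora_merged_state_dict() folds these deltas into the base weights and
# strips the "base_model.model." prefix, yielding keys vLLM can load directly.
```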

trl/trainer/utils.py (+62 −0)

```diff
@@ -60,6 +60,8 @@

 if is_peft_available():
     from peft import LoraConfig, PeftConfig
+    from peft.tuners.tuners_utils import onload_layer
+    from peft.utils import ModulesToSaveWrapper, _get_submodules


 class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
@@ -920,6 +922,66 @@ def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]":

     return peft_config

+
+def get_lora_merged_state_dict(
+    model: torch.nn.Module,
+) -> dict:
+    r"""
+    Create and return a state_dict with the LoRA deltas merged into the base model's weights, without modifying
+    `model` in place.
+
+    Args:
+        model (`torch.nn.Module`): A model with LoRA/PEFT adapters attached.
+
+    Returns:
+        `dict`: A state_dict of the merged parameters.
+    """
+    if not is_peft_available():
+        raise ValueError(
+            "You need to have the PEFT library installed in your environment. Make sure to run `pip install -U peft`."
+        )
+
+    base_model_prefix = "base_model.model."
+    state_dict = {}
+    # walk every module whose name does not contain the adapter prefix (e.g. "lora_")
+    key_list = [key for key, _ in model.named_modules() if model.prefix not in key]
+    for key in key_list:
+        try:
+            _, target, _ = _get_submodules(model, key)
+        except AttributeError:
+            continue
+        with onload_layer(target):
+            weight_key = key.replace(base_model_prefix, "") + ".weight"
+            bias_key = key.replace(base_model_prefix, "") + ".bias"
+            layer_state_dict = {}  # stays empty when no branch below matches
+            if hasattr(target, "base_layer"):
+                # a LoRA-wrapped layer: merge the deltas, then read the base weights
+                target.merge(safe_merge=True, adapter_names=None)
+                layer_state_dict = target.base_layer.state_dict()
+                state_dict[weight_key] = layer_state_dict["weight"]
+            elif isinstance(target, ModulesToSaveWrapper):
+                # save any additional trainable modules part of `modules_to_save`
+                new_module = target.modules_to_save[target.active_adapter]
+                if hasattr(new_module, "base_layer"):
+                    # the replacement module is itself a tuner layer; merge it too
+                    new_module.merge(safe_merge=True, adapter_names=None)
+                layer_state_dict = new_module.state_dict()
+                state_dict[weight_key] = layer_state_dict["weight"]
+            elif hasattr(target, "weight"):
+                # skip PEFT's internal duplicate views of the same parameters
+                if any(skip in key for skip in [".original_module", ".modules_to_save", ".base_layer"]):
+                    continue
+                layer_state_dict = target.state_dict()
+                state_dict[weight_key] = layer_state_dict["weight"]
+            if hasattr(target, "bias") and "bias" in layer_state_dict:
+                state_dict[bias_key] = layer_state_dict["bias"]
+    return state_dict
+

 def get_exp_cap(value, decimal=4):
     """
```
