
Commit a031699

handle native lora merge before shipping to vllm
1 parent 182ade9 commit a031699

File tree: 2 files changed (+13 -62 lines changed)

trl/trainer/grpo_trainer.py (+13 -2)
@@ -47,7 +47,7 @@
 from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
-from .utils import generate_model_card, get_comet_experiment_url, get_lora_merged_state_dict, pad, selective_log_softmax
+from .utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax
 
 
 if is_peft_available():
@@ -451,7 +451,18 @@ def _move_model_to_vllm(self):
             if is_compiled_module(unwrapped_model):
                 unwrapped_model = unwrapped_model._orig_mod
             if is_peft_model(unwrapped_model):
-                state_dict = get_lora_merged_state_dict(unwrapped_model)
+                unwrapped_model.merge_adapter()
+                state_dict = unwrapped_model.state_dict()
+                state_dict = {
+                    k.removeprefix("base_model.model.")
+                    .removeprefix("base_model.model.")
+                    .replace(".default", "")
+                    .replace(".base_layer", "")
+                    .replace(".modules_to_save", ""): v
+                    for k, v in state_dict.items()
+                    if unwrapped_model.prefix not in k and "original_module" not in k
+                }
+                unwrapped_model.unmerge_adapter()
             else:
                 state_dict = unwrapped_model.state_dict()
         if self.accelerator.is_main_process:
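
For reference, the new code path relies on PEFT's model-level merge_adapter() / unmerge_adapter() API instead of the hand-rolled per-layer merge the commit deletes from utils.py below. A minimal, self-contained sketch of the same pattern follows; the base checkpoint and LoraConfig values are illustrative assumptions, not part of the commit:

# Sketch of the native merge/unmerge pattern used in _move_model_to_vllm above.
# The model name and LoraConfig values are illustrative assumptions.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")  # any causal LM
model = get_peft_model(base, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))

model.merge_adapter()  # fold the LoRA deltas into the base weights, in place
state_dict = model.state_dict()
# Strip PEFT naming artifacts so the keys line up with the plain base model,
# mirroring the dict comprehension in the diff above.
state_dict = {
    k.removeprefix("base_model.model.")
    .replace(".default", "")
    .replace(".base_layer", "")
    .replace(".modules_to_save", ""): v
    for k, v in state_dict.items()
    # drop adapter-only tensors (model.prefix is "lora_") and kept originals
    if model.prefix not in k and "original_module" not in k
}
model.unmerge_adapter()  # restore the adapters so training can continue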

trl/trainer/utils.py (-60)
@@ -922,66 +922,6 @@ def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]":
 
     return peft_config
 
-def get_lora_merged_state_dict(
-    model: torch.nn.Module,
-) -> dict:
-    r"""
-    Create and return a state_dict that has the LoRA deltas
-    merged into the base model’s weights, without modifying `model` in place.
-
-    Arguments:
-        model (torch.nn.Module): A model that has LoRA/PEFT adapters attached.
-
-    Returns:
-        dict: A state_dict of the merged parameters.
-    """
-
-    if not is_peft_available():
-        raise ValueError(
-            "You need to have PEFT library installed in your environment, make sure to install `peft`. "
-            "Make sure to run `pip install -U peft`."
-        )
-
-    base_model_prefix = "base_model.model."
-    state_dict = {}
-    key_list = [key for key, _ in model.named_modules() if model.prefix not in key]
-    for key in key_list:
-        try:
-            _, target, _ = _get_submodules(model, key)
-        except AttributeError:
-            continue
-        with onload_layer(target):
-            weight_key = key.replace(base_model_prefix, "") + ".weight"
-            bias_key = key.replace(base_model_prefix, "") + ".bias"
-            if hasattr(target, "base_layer"):
-                target.merge(safe_merge=True, adapter_names=None)
-                # get the state_dict of target.base_layer
-                layer_state_dict = target.base_layer.state_dict()
-                state_dict[weight_key] = layer_state_dict["weight"]
-            elif isinstance(target, ModulesToSaveWrapper):
-                # save any additional trainable modules part of `modules_to_save`
-                new_module = target.modules_to_save[target.active_adapter]
-                if hasattr(new_module, "base_layer"):
-                    # check if the module is itself a tuner layer
-                    new_module.merge(safe_merge=True, adapter_names=None)
-                layer_state_dict = new_module.state_dict()
-                state_dict[weight_key] = layer_state_dict["weight"]
-            elif hasattr(target, "weight"):
-                if any(
-                    skip in key
-                    for skip in [
-                        ".original_module",
-                        ".modules_to_save",
-                        ".base_layer",
-                    ]
-                ):
-                    continue
-                layer_state_dict = target.state_dict()
-                state_dict[weight_key] = layer_state_dict["weight"]
-            if hasattr(target, "bias") and "bias" in layer_state_dict.keys():
-                state_dict[bias_key] = layer_state_dict["bias"]
-    return state_dict
-
 
 def get_exp_cap(value, decimal=4):
     """
