
👨‍👩‍👧 GRPO + PEFT + vLLM #2818

Merged 5 commits into huggingface:main on Feb 13, 2025

Conversation

winglian (Contributor)

What does this PR do?

Unlocks PEFT + GRPO + vLLM without the complexity of shipping LoRA weights to vLLM via the REST API. This implementation simply merges the LoRA weights into the base model and ships the merged weights to vLLM using the existing Python API.
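
A rough sketch of the approach (the helper name and the way the vLLM-side model handle is obtained are assumptions for illustration, not the exact PR code):

    # Sketch only: merge the PEFT adapter into the base weights, strip the
    # PEFT-specific prefixes from the state dict, then hand plain weights to vLLM.
    from peft import PeftModel

    def get_merged_state_dict(model: PeftModel) -> dict:
        model.merge_adapter()  # fold the LoRA/DoRA weights into the base weights
        state_dict = {
            k.removeprefix("base_model.model.").replace(".base_layer", ""): v
            for k, v in model.state_dict().items()
            if model.prefix not in k  # drop the adapter tensors themselves
        }
        model.unmerge_adapter()  # restore the trainable adapter afterwards
        return state_dict

    # On the vLLM side, the merged weights can then be loaded in-process, e.g.
    # llm_model.load_weights(get_merged_state_dict(peft_model).items()), where
    # llm_model is the model object held by the vLLM engine.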

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Comment on lines 954 to 955
weight_key = key.replace(base_model_prefix, "") + ".weight"
bias_key = key.replace(base_model_prefix, "") + ".bias"
winglian (Contributor Author):

I know this is pretty janky, so would love feedback on making it better.

Contributor:

Does it work to iterate through model.base_model.model.named_modules() at L947 to get the named parameters w/o the "model.base_model" prefix?
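
For illustration, a sketch of what that suggestion amounts to (assuming model is the PEFT-wrapped model under discussion):

    # Sketch: iterating from the inner base model yields parameter names without
    # the "base_model.model." prefix that the PeftModel wrapper adds, e.g.
    # "model.layers.0.self_attn.q_proj.base_layer.weight".
    names = [name for name, _ in model.base_model.model.named_parameters()]
    assert not any(name.startswith("base_model.") for name in names)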

Comment on lines 970 to 978
if any(
    skip in key
    for skip in [
        ".original_module",
        ".modules_to_save",
        ".base_layer",
    ]
):
    continue
winglian (Contributor Author):

Same for this.

@qgallouedec (Member):

@winglian just to point out a different approach: #2730

@winglian (Contributor Author):

@qgallouedec The downside there is that you're limited to the LoRA support in vLLM, which means no DoRA support. With this approach, almost any PEFT adapter type can be used. And while LoRA does converge pretty quickly compared to full-parameter training, DoRA seems to be more performant.

[Screenshot attached: 2025-02-10 at 9:25 AM]
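
For context, a sketch of what enabling DoRA looks like on the PEFT side (values are illustrative, not taken from this PR):

    from peft import LoraConfig

    # Sketch: because the merge happens in PEFT before the weights are shipped
    # to vLLM, adapter variants such as DoRA only need the config flag; vLLM
    # never sees the adapter itself.
    dora_config = LoraConfig(
        r=64,
        lora_alpha=64,  # illustrative values
        target_modules="all-linear",
        use_dora=True,
    )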

@qgallouedec (Member):

This seems quite reasonable, thank you for the clear explanation.

@qgallouedec (Member) commented Feb 10, 2025:

Another pointer that could be useful:

> It is possible to call model.merge_adapter (optionally with the adapter_names argument), then model.state_dict(), then model.unmerge_adapter.
> The state_dict may require some cleanup though, depending on what you need to do with it (I couldn't infer that from the PR).
> By cleanup, I mean: after merge_and_unload the model looks like the base model, but merge_adapter keeps the LoRA structure, with the wrapped base model, LoRA weights, etc. still present in the state_dict.

From @BenjaminBossan

@winglian (Contributor Author):

I tried

                unwrapped_model.merge_and_unload()
                state_dict = unwrapped_model.base_model.model.state_dict()
                unwrapped_model.unmerge_adapter()

but the resulting state dict still has the base_model.model. prefix.

@BenjaminBossan (Member) left a comment:

Thanks a lot for this PR.

To elaborate on the quote by Quentin, the steps would be:

  1. Call model.merge_adapter().
  2. Get the state_dict of the merged model.
  3. Clean up the state_dict: since the base weights already contain the merged LoRA weights, we can remove all LoRA weights.
  4. Call model.unmerge_adapter() if we need to restore the previous state (note that unmerge_adapter unmerges all adapters, so if some were already merged before step 1, they need to be re-merged, but it's probably not relevant here).

Here is a small demonstration in code:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model_id = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(model_id)
config = LoraConfig()
model = get_peft_model(model, config)

model.merge_adapter()  # merge the LoRA weights into the base weights
sd = model.state_dict()
# Drop the LoRA tensors themselves (model.prefix == "lora_") and strip the PEFT
# wrapper prefixes so the keys match the plain base model.
new_sd = {
    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
    for k, v in sd.items()
    if model.prefix not in k
}
model.unmerge_adapter()  # restore the unmerged adapter state
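
As a quick sanity check (an added sketch, assuming there are no modules_to_save), the cleaned keys should load directly into the plain base model:

    # Sketch: the cleaned state dict should line up 1:1 with the un-wrapped base model.
    base = AutoModelForCausalLM.from_pretrained(model_id)
    missing, unexpected = base.load_state_dict(new_sd, strict=False)
    print(missing, unexpected)  # both should be empty if the cleanup worked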

@qgallouedec (Member) commented Feb 10, 2025:

I've added the suggested modification to this branch: #2725. It seems to work...! EDIT: DoRA included

@BenjaminBossan (Member):

> I've added the suggested modification to this branch: #2725. It seems to work...! EDIT: DoRA included

Nice, I added a comment there. Hopefully, one of these branches can be merged soon :)

@winglian (Contributor Author):

I redid this PR to account for the other changes, and also updated the test to use LoRA.

@BenjaminBossan (Member) left a comment:

Not sure if this PR is still required after #2725 has been merged, but I did a quick review just in case.

Comment on lines 499 to 500
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
Member:

Duplicate, you can remove the 2nd line.

k.removeprefix("base_model.model.").replace(".base_layer", ""): v
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".default", "")
Member:

I'll leave the same comment as I did on #2725:

Note here that the adapter name can be different from "default". You could get the adapter name from model.active_adapters, which is a list of all active adapters. I assume in this context, there can only ever be one (raise an error when more?), so taking the first item should work.
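
A sketch of that suggestion (with sd being the merged state_dict from the demo above, and assuming exactly one active adapter):

    # Sketch: resolve the adapter name instead of hard-coding "default".
    active_adapters = model.active_adapters
    if len(active_adapters) != 1:
        raise ValueError(f"Expected exactly one active adapter, got {active_adapters}")
    adapter_name = active_adapters[0]

    new_sd = {
        k.removeprefix("base_model.model.")
         .replace(".base_layer", "")
         .replace(f".{adapter_name}", ""): v
        for k, v in sd.items()
        if model.prefix not in k
    }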

@qgallouedec (Member):

Thanks for the follow-up, @BenjaminBossan!

@@ -249,7 +249,7 @@ def __init__(
         # Reference model
         if is_deepspeed_zero3_enabled():
             self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
-        elif peft_config is None:
+        elif not is_peft_model(model):
Member:

This allows supporting a model that is already wrapped by PEFT.
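
For illustration, a sketch of what this enables (the trainer arguments are placeholders borrowed from the examples in this thread):

    from peft import LoraConfig, get_peft_model

    # Sketch: the model can be wrapped with PEFT up front and passed directly,
    # without also passing peft_config to the trainer.
    peft_model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32))
    trainer = GRPOTrainer(
        model=peft_model,
        reward_funcs=[format_reward],  # placeholder reward function
        args=training_args,            # placeholder GRPOConfig
        train_dataset=dataset,
    )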

@qgallouedec changed the title from "GRPO + PEFT + vLLM" to "👨‍👩‍👧 GRPO + PEFT + vLLM" on Feb 13, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec merged commit 5c9cf20 into huggingface:main on Feb 13, 2025
@mehdiataei:

Using a Qwen1.5 instruct model, I get the following error:


[rank0]:     trainer.train()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2171, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2531, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3669, in training_step
[rank0]:     inputs = self._prepare_inputs(inputs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/trl/trainer/grpo_trainer.py", line 535, in _prepare_inputs
[rank0]:     self._move_model_to_vllm()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/trl/trainer/grpo_trainer.py", line 515, in _move_model_to_vllm
[rank0]:     llm_model.load_weights(state_dict.items())
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/qwen2.py", line 515, in load_weights
[rank0]:     return loader.load_weights(weights)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/utils.py", line 235, in load_weights
[rank0]:     autoloaded_weights = set(self._load_module("", self.module, weights))
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/utils.py", line 224, in _load_module
[rank0]:     raise ValueError(msg)
[rank0]: ValueError: There is no module or parameter named 'base_model' in Qwen2ForCausalLM

The trainer is set up as:

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[format_reward, judge_reward],
        args=training_args,
        train_dataset=dataset,
        peft_config=lora_config,
    )

with the following vLLM settings:

    use_vllm=True,                       # Whether to use vLLM for faster generation (default: False)
    vllm_device="cuda:7",                   # Device for vLLM generation (e.g., "cuda:1"); "auto" selects the next available GPU
    vllm_gpu_memory_utilization=0.4,       # Fraction of GPU memory to reserve for vLLM (default: 0.9)
    vllm_dtype="auto",                    # Data type for vLLM generation; "auto" lets vLLM decide based on model config
    vllm_max_model_len=512,              # Optional maximum model length for vLLM; if None, uses the model's context size

Another weird thing I noticed:

INFO 02-13 17:12:27 model_runner.py:1115] Loading model weights took 0.0000 GB
INFO 02-13 17:12:28 worker.py:267] Memory profiling takes 0.48 seconds
INFO 02-13 17:12:28 worker.py:267] the current vLLM instance can use total_gpu_memory (39.39GiB) x gpu_memory_utilization (0.40) = 15.76GiB
INFO 02-13 17:12:28 worker.py:267] model weights take 0.00GiB; non_torch_memory takes 0.00GiB; PyTorch activation peak memory takes 0.00GiB; the rest of the memory reserved for KV Cache is 15.76GiB

Why do the model weights take 0.00 GiB?

@zaddy6 commented Feb 14, 2025:

I noticed training without LoRA leads to better performance. Here is an example: without LoRA it starts to max out the rewards at 1k steps; with LoRA it doesn't learn.
[image attached]

@winglian (Contributor Author):

> I noticed training without LoRA leads to better performance. Here is an example: without LoRA it starts to max out the rewards at 1k steps; with LoRA it doesn't learn.

What rank and dataset? It learns pretty quickly with rank 64 on the gsm8k dataset.

@zaddy6 commented Feb 14, 2025:

> I noticed training without LoRA leads to better performance. Here is an example: without LoRA it starts to max out the rewards at 1k steps; with LoRA it doesn't learn.
>
> What rank and dataset? It learns pretty quickly with rank 64 on the gsm8k dataset.

Current config:

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules="all-linear",
    lora_dropout=0.05,
    use_dora=True,
)

What do you use as your alpha?

@wusijie123:

The total steps changed when using this code:
[image attached]
Why is Total optimization steps = Num examples * Num Epochs / Gradient Accumulation steps? In the previous version, Total optimization steps = Num examples * Num Epochs / Total train batch size.

By the way, even using LoRA, most of the time I am waiting for vLLM to generate results. Is there any way to speed up the generation?
