👨👩👧 GRPO + PEFT + vLLM #2818
Conversation
trl/trainer/utils.py
Outdated
weight_key = key.replace(base_model_prefix, "") + ".weight"
bias_key = key.replace(base_model_prefix, "") + ".bias"
I know this is pretty janky, so would love feedback on making it better.
Does it work to iterate through model.base_model.model.named_modules()
at L947 to get the named parameters w/o the "model.base_model" prefix?
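For concreteness, a minimal sketch of that suggestion, assuming model is a PeftModel wrapping a causal LM (the actual loop in utils.py may need extra handling):

```python
# Sketch only: iterating over the inner base model should yield parameter
# names without the "base_model.model." wrapper prefix added by PEFT.
for module_name, module in model.base_model.model.named_modules():
    for param_name, param in module.named_parameters(recurse=False):
        name = f"{module_name}.{param_name}" if module_name else param_name
        print(name, tuple(param.shape))
```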
trl/trainer/utils.py
Outdated
if any(
    skip in key
    for skip in [
        ".original_module",
        ".modules_to_save",
        ".base_layer",
    ]
):
    continue
same for this
@qgallouedec The downside there is that you're limited to the LoRA support in vLLM, which means no DoRA support. With this approach, almost any PEFT adapter type could be used. While LoRA does converge pretty quickly compared to full-parameter training, DoRA seems to be more performant.
This seems quite reasonable, thank you for the clear explanation.
Another pointer that could be useful, from @BenjaminBossan:
I tried
but the resulting state dict still has the prefix of
Thanks a lot for this PR.
To elaborate on the quote by Quentin, the steps would be:
- Call model.merge_adapter().
- Get the state_dict of the merged model.
- Clean up the state_dict: since the base weights already contain the merged LoRA weights, we can remove all LoRA weights.
- Call model.unmerge_adapter() if we need to restore the previous state (note that unmerge_adapter unmerges all adapters, so if some were already merged before step 1, they need to be re-merged, but it's probably not relevant here).
Here is a small demonstration in code:
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model_id = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(model_id)
config = LoraConfig()
model = get_peft_model(model, config)

model.merge_adapter()  # fold the LoRA weights into the base weights
sd = model.state_dict()
# drop the LoRA parameters (model.prefix == "lora_") and strip the PEFT wrapper prefixes
new_sd = {
    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
    for k, v in sd.items()
    if model.prefix not in k
}
model.unmerge_adapter()  # restore the unmerged state
I've added the suggested modification to this branch: #2725. It seems to work...! EDIT: DoRA included
Nice, I added a comment there. Hopefully, one of these branches can be merged soon :)
I re-did this PR to account for the other changes, and also updated the test to use LoRA.
Not sure if this PR is still required after #2725 has been merged, but I did a quick review just in case.
trl/trainer/grpo_trainer.py
Outdated
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
Duplicate, you can remove the 2nd line.
trl/trainer/grpo_trainer.py
Outdated
k.removeprefix("base_model.model.").replace(".base_layer", ""): v
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".default", "")
I'll leave the same comment as I did on #2725:
Note here that the adapter name can be different from "default". You could get the adapter name from model.active_adapters, which is a list of all active adapters. I assume in this context, there can only ever be one (raise an error when more?), so taking the first item should work.
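A small sketch of that suggestion, reusing the key-cleanup pattern from the demo above (assumes model is a PeftModel with a single active adapter; the raise-on-multiple check is an assumption):

```python
# Sketch: derive the adapter name instead of hard-coding "default".
active_adapters = model.active_adapters  # list of active adapter names
if len(active_adapters) > 1:
    raise ValueError(f"Expected one active adapter, got {active_adapters}")
adapter_name = active_adapters[0]

state_dict = {
    k.removeprefix("base_model.model.")
    .replace(f".{adapter_name}", "")
    .replace(".base_layer", ""): v
    for k, v in model.state_dict().items()
    if model.prefix not in k  # drop the LoRA parameters themselves
}
```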
Thanks for the follow-up, @BenjaminBossan!
@@ -249,7 +249,7 @@ def __init__(
     # Reference model
     if is_deepspeed_zero3_enabled():
         self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
-    elif peft_config is None:
+    elif not is_peft_model(model):
This allows supporting a model that is already wrapped by PEFT.
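To make the effect of this change concrete, here is a hypothetical standalone helper that mirrors the branching (the function name and fallback logic are illustrative assumptions, not the trainer's actual code):

```python
from transformers import AutoModelForCausalLM
from trl import create_reference_model


def build_reference_model(model, model_id, zero3_enabled=False, **model_init_kwargs):
    """Illustrative only: mirrors the reference-model branching discussed above."""
    try:
        from peft import PeftModel
        model_is_peft = isinstance(model, PeftModel)
    except ImportError:
        model_is_peft = False

    if zero3_enabled:
        # Under DeepSpeed ZeRO-3, reload the reference model from the checkpoint.
        return AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
    if not model_is_peft:
        # Plain model: keep a frozen copy around as the reference.
        return create_reference_model(model)
    # Already wrapped by PEFT (whether passed in pre-wrapped or created from
    # peft_config): no separate copy is needed, since the base weights with the
    # adapter disabled serve as the reference.
    return None
```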
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Using the Qwen1.5 instruct model, I face the following error:
The format is:
with the following vLLM settings:
Another weird thing that I noticed is that
Why the
What rank and dataset? It learns pretty quickly with rank 64 on the gsm8k dataset.
Current config:
lora_config = LoraConfig(
What do you use as your alpha?
What does this PR do?
This PR unlocks PEFT + GRPO + vLLM without the complexity of shipping LoRA weights to vLLM via the REST API. The implementation simply merges the LoRA weights into the base model and ships them to vLLM using the existing Python API.
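As a rough sketch of that flow, assuming model is the PEFT-wrapped policy and llm is a colocated vllm.LLM instance (the attribute path into the vLLM engine is version-dependent and only an assumption here):

```python
# Illustrative sketch, not the exact trainer code.
model.merge_adapter()  # fold the adapter weights into the base weights
state_dict = {
    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
    for k, v in model.state_dict().items()
    if model.prefix not in k  # drop the LoRA parameters themselves
}
model.unmerge_adapter()  # restore the trainable adapter state

# Push the merged weights into vLLM in-process (no REST round-trip).
# The path below matches older vLLM releases and may differ in newer ones.
llm_model = llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
```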
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.