Dynamically load LoRA weights when using vLLM #2730
base: fix-peft-vllm-grpo
Conversation
Hey, really awesome work! I tried incorporating this PR locally, but ran into some issues with the Triton kernels. Before I post the full traceback, do you have any code snippets where you initialize the trainer?
Hi @tchang1997, I haven't seen that. It sounds like it could be coming from vLLM, since TRL doesn't use Triton directly (to my knowledge). Can you share the stack trace?
Sure — see below. The underlying issue seems to be raised by vLLM. FWIW, I also just tried
Here's the stack trace:

```
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner_base.py", line 116, in _wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1721, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 539, in forward
[rank0]:     model_output = self.model(input_ids, positions, kv_caches,
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/compilation/decorators.py", line 170, in __call__
[rank0]:     return self.forward(*args, **kwargs)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 354, in forward
[rank0]:     hidden_states = self.get_input_embeddings(input_ids)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 339, in get_input_embeddings
[rank0]:     return self.embed_tokens(input_ids)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/lora/layers.py", line 260, in forward
[rank0]:     self.punica_wrapper.add_lora_embedding(full_output,
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/lora/ops/triton_ops/sgmv_expand.py", line 220, in _sgmv_expand
[rank0]:     _sgmv_expand_kernel[grid](
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
[rank0]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/triton/runtime/jit.py", line 691, in run
[rank0]:     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
[rank0]:     self.launch(*args, **kwargs)
[rank0]: ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ctrenton/steerability/rl.py", line 259, in <module>
[rank0]:     trainer = GRPOTrainer(
[rank0]:               ^^^^^^^^^^^^
[rank0]:   File "/home/ctrenton/trl/trl/trainer/grpo_trainer.py", line 338, in __init__
[rank0]:     self.llm = LLM(
[rank0]:                ^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/utils.py", line 1039, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 240, in __init__
[rank0]:     self.llm_engine = self.engine_class.from_engine_args(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 482, in from_engine_args
[rank0]:     engine = cls(
[rank0]:              ^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 274, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 414, in _initialize_kv_caches
[rank0]:     self.model_executor.determine_num_available_blocks())
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 99, in determine_num_available_blocks
[rank0]:     results = self.collective_rpc("determine_num_available_blocks")
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 49, in collective_rpc
[rank0]:     answer = run_method(self.driver_worker, method, args, kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/utils.py", line 2208, in run_method
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/worker.py", line 228, in determine_num_available_blocks
[rank0]:     self.model_runner.profile_run()
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1236, in profile_run
[rank0]:     self._dummy_run(max_num_batched_tokens, max_num_seqs)
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1347, in _dummy_run
[rank0]:     self.execute_model(model_input, kv_caches, intermediate_tensors)
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner_base.py", line 152, in _wrapper
[rank0]:     raise type(err)(
[rank0]: ValueError: Error in model execution (input dumped to /tmp/err_execute_model_input_20250207-175904.pkl): Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
```

There's an identical message w/ `[rank1]` since I'm using 2 GPUs for training.

Here's `trl env`:
In my main script, the error occurs when initializing the trainer:
Here's my peft config — everything's in a YAML file and parsed via
I can share the rest if it's helpful.
As an update, I traced the underlying issue to vLLM — the error comes from vLLM's initial profiling run, specifically the dummy LoRA requests, where some tensors are initialized on CPU. Interestingly, the issue I described goes away when I set the LLM device to
Same issue on my side on 2xA100.
I'm not sure whether this PR is superseded by #2818, but just in case it's not, I've added my thoughts on it.
if peft_config is not None:
    model = get_peft_model(model, peft_config)
    if isinstance(peft_config, LoraConfig):
        lora_rank = peft_config.r
Just note that there could be different ranks for different layers: https://huggingface.co/docs/peft/v0.14.0/en/package_reference/lora#peft.LoraConfig.rank_pattern. Not sure whether vLLM supports that. This might be relevant for determining the `max_lora_rank` below.
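As an illustration only (the config values below are made up, not taken from this PR), the maximum effective rank under a `rank_pattern` could be derived like this:

```python
from peft import LoraConfig

# Hypothetical config where rank_pattern gives some layers a larger rank than r.
peft_config = LoraConfig(r=8, rank_pattern={"q_proj": 16, "down_proj": 32})

# Consider both the default rank and any per-layer overrides when picking the
# value to hand to vLLM as max_lora_rank.
candidate_ranks = [peft_config.r, *peft_config.rank_pattern.values()]
max_lora_rank = max(candidate_ranks)  # 32 in this example
```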
from vllm.lora.request import LoRARequest

# Enable runtime LoRA updating on every new checkpoint
os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
I'm not a fan of permanently setting env vars, as this can lead to hard-to-debug errors. Ideally, something like `patch_environment` from accelerate could be used instead. Otherwise, a destructor (`__del__`) could be defined on the appropriate object, but that sounds brittle.
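A minimal sketch of that suggestion, assuming the variable only needs to be set while the vLLM engine is being constructed (whether vLLM re-reads it later would need to be checked):

```python
from accelerate.utils import patch_environment
from vllm import LLM

# The env var exists only inside this context and is restored on exit,
# so it never leaks into the rest of the process.
with patch_environment(VLLM_ALLOW_RUNTIME_LORA_UPDATING="True"):
    llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_lora=True)
```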
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
    if isinstance(unwrapped_model, PeftModel):
        unwrapped_model = copy.deepcopy(unwrapped_model)
        unwrapped_model.merge_and_unload()
Similar comment as for the other PR: consider using `model.merge_adapter`, cleaning up the state dict, then `model.unmerge_adapter`, instead of creating a whole copy of the model.
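A rough sketch of that alternative, assuming `unwrapped_model` is a `PeftModel`; the key filtering shown is an assumption about what "cleaning up the state dict" would involve, not code from this PR:

```python
# Merge the LoRA weights into the base weights in place (no deep copy needed).
unwrapped_model.merge_adapter()

# Snapshot the merged weights, dropping LoRA-specific entries and the PEFT
# wrapper prefix so the keys match the base model's naming.
state_dict = {
    name.removeprefix("base_model.model."): param
    for name, param in unwrapped_model.state_dict().items()
    if "lora_" not in name
}

# Undo the in-place merge so training continues with separate adapter weights.
unwrapped_model.unmerge_adapter()
```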
        unwrapped_model.merge_and_unload()
    state_dict = unwrapped_model.state_dict()

if self.accelerator.is_main_process:
AFAICT, the `state_dict` is only needed on the main process. So how about moving the whole code block above into this `if` block?
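For illustration, one way that reshuffle might look (hedged: if the model is sharded, e.g. with ZeRO-3, the gathering inside `unwrap_model_for_generation` may still need to run on every rank, so this is only a sketch of the idea):

```python
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
    if self.accelerator.is_main_process:
        if isinstance(unwrapped_model, PeftModel):
            # Merge in place, snapshot, then unmerge (see the comment above),
            # so only the main process pays for building the state dict.
            unwrapped_model.merge_adapter()
            state_dict = unwrapped_model.state_dict()
            unwrapped_model.unmerge_adapter()
        else:
            state_dict = unwrapped_model.state_dict()
        # ...hand state_dict to the vLLM engine / write the LoRA checkpoint...
```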
This PR implements the proposed improvement from #2725 and dynamically loads LoRA adapters into vLLM instead of merging LoRA weights back into the base model at each step. In practice this is much faster and less memory-intensive than merging.
The only caveat I would flag is that vLLM does appear to leak host memory when dynamically loading LoRA adapters over and over. LoRAs are small, so this isn't necessarily going to cause failures, but the safest solution we've found when testing this internally has been to periodically recreate the `LLM` instance every 50-100 steps (in my experience this can safely be done in the same part of the code that writes out the LoRA checkpoint). It would be good to file an issue with the vLLM team so they can investigate this at some point (@Jeffwan, is this something you've encountered in your work on LoRA in vLLM?). This is an adaptation of some code my team is using on a fork of TRL, so it would be great if someone like @qgallouedec would be willing to commandeer this PR and test it further to ensure everything is working as intended.
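For readers who haven't used vLLM's runtime LoRA loading, here is a minimal, self-contained sketch of the general mechanism this PR relies on (the model name, adapter path, and step number are placeholders, and this is not the PR's actual trainer code):

```python
import os

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Allow LoRA adapters to be (re)loaded at runtime.
os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_lora=True, max_lora_rank=32)

# After saving a new checkpoint, point a fresh LoRARequest at the adapter;
# a new (name, id) pair makes vLLM load it rather than reuse a cached one.
step = 100
adapter_path = f"checkpoints/step-{step}/lora"  # placeholder path
lora_request = LoRARequest(f"grpo-adapter-{step}", step, adapter_path)

outputs = llm.generate(
    ["Write a haiku about reinforcement learning."],
    SamplingParams(max_tokens=64),
    lora_request=lora_request,
)
print(outputs[0].outputs[0].text)
```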