Dynamically load LoRA weights when using vLLM #2730

Open · wants to merge 1 commit into base: fix-peft-vllm-grpo
Conversation

@tgaddair commented Feb 1, 2025

This PR implements the proposed improvement from #2725: it dynamically loads LoRA adapters into vLLM instead of merging the LoRA weights back into the base model at each step. In practice this is much faster and less memory-intensive than merging.
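For context, here is a minimal sketch of the general pattern (not the exact code in this PR): the trainer saves the current PEFT adapter to disk whenever the weights change, and generation passes a LoRARequest pointing at that checkpoint instead of merged weights. The model name, adapter path, and sampling settings are placeholders.

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora / max_lora_rank must be compatible with the adapter being trained
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", enable_lora=True, max_lora_rank=16)

def generate_with_current_adapter(prompts, adapter_dir, step):
    # adapter_dir is wherever the trainer wrote the PEFT checkpoint for this step;
    # the integer adapter ID must be positive and unique per loaded adapter
    lora_request = LoRARequest(f"grpo_adapter_step_{step}", step + 1, adapter_dir)
    return llm.generate(prompts, SamplingParams(max_tokens=128), lora_request=lora_request)
```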

The only caveat I would flag is that vLLM appears to leak host memory when LoRA adapters are dynamically loaded over and over. LoRAs are small, so this won't necessarily cause failures, but the safest solution we've found when testing this internally has been to periodically recreate the LLM instance every 50-100 steps (in my experience this can safely be done in the same part of the code that writes out the LoRA checkpoint). It would be good to file an issue with the vLLM team so they can investigate this at some point (@Jeffwan, is this something you've encountered in your work on LoRA in vLLM?).
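For illustration only, a rough sketch of that periodic re-creation workaround; the interval, helper name, and cleanup steps are placeholders, and whether an in-process teardown fully releases memory depends on the vLLM version.

```python
import gc
import torch
from vllm import LLM

RECREATE_EVERY = 50  # 50-100 steps, as noted above; purely empirical

def maybe_recreate_llm(llm, step, model_name, max_lora_rank):
    """Rebuild the vLLM engine every RECREATE_EVERY steps to bound the host-memory growth."""
    if step > 0 and step % RECREATE_EVERY == 0:
        del llm
        gc.collect()
        torch.cuda.empty_cache()
        llm = LLM(model=model_name, enable_lora=True, max_lora_rank=max_lora_rank)
    return llm
```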

This is an adaptation of code my team is using on a fork of TRL, so it would be great if someone like @qgallouedec were willing to commandeer this PR and test it further to make sure everything works as intended.

@tgaddair tgaddair mentioned this pull request Feb 1, 2025
@tchang1997

Hey, really awesome work! I tried incorporating this PR locally, but the __init__ call in GRPOTrainer fails with `ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)`. I'm using accelerate launch, if that's relevant.

Before I post the full traceback, do you have any code snippets where you initialize GRPOTrainer successfully? I'm also working on a custom fork of trl, so I just want to make sure I didn't introduce a silly option clash. FWIW, I'm able to proceed with full-parameter tuning successfully if I don't use a peft config.

@tgaddair (Author) commented Feb 7, 2025

Hi @tchang1997, I haven't seen that. It sounds like it could be coming from vLLM, since TRL doesn't use Triton to my knowledge. Can you share the stack trace?

@tchang1997 commented Feb 8, 2025

Sure, see below. The underlying issue seems to be raised by vLLM. FWIW, I also just tried:

  • initializing LLM manually in the interpreter w/ the same args as I use in the script -> no issues ✅
  • accelerate launch with --num-processes 1 (instead of my usual 2) -> issue persists ❌
  • python script.py instead of accelerate launch script.py -> issue persists ❌
Here's the stack trace:

```
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner_base.py", line 116, in _wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1721, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 539, in forward
[rank0]:     model_output = self.model(input_ids, positions, kv_caches,
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/compilation/decorators.py", line 170, in __call__
[rank0]:     return self.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 354, in forward
[rank0]:     hidden_states = self.get_input_embeddings(input_ids)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 339, in get_input_embeddings
[rank0]:     return self.embed_tokens(input_ids)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/lora/layers.py", line 260, in forward
[rank0]:     self.punica_wrapper.add_lora_embedding(full_output,
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/lora/ops/triton_ops/sgmv_expand.py", line 220, in _sgmv_expand
[rank0]:     _sgmv_expand_kernel[grid](
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
[rank0]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank0]:                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/triton/runtime/jit.py", line 691, in run
[rank0]:     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
[rank0]:     self.launch(*args, **kwargs)
[rank0]: ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ctrenton/steerability/rl.py", line 259, in <module>
[rank0]:     trainer = GRPOTrainer(
[rank0]:               ^^^^^^^^^^^^
[rank0]:   File "/home/ctrenton/trl/trl/trainer/grpo_trainer.py", line 338, in __init__
[rank0]:     self.llm = LLM(
[rank0]:                ^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/utils.py", line 1039, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 240, in __init__
[rank0]:     self.llm_engine = self.engine_class.from_engine_args(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 482, in from_engine_args
[rank0]:     engine = cls(
[rank0]:              ^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 274, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 414, in _initialize_kv_caches
[rank0]:     self.model_executor.determine_num_available_blocks())
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 99, in determine_num_available_blocks
[rank0]:     results = self.collective_rpc("determine_num_available_blocks")
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 49, in collective_rpc
[rank0]:     answer = run_method(self.driver_worker, method, args, kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/utils.py", line 2208, in run_method
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/worker.py", line 228, in determine_num_available_blocks
[rank0]:     self.model_runner.profile_run()
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1236, in profile_run
[rank0]:     self._dummy_run(max_num_batched_tokens, max_num_seqs)
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1347, in _dummy_run
[rank0]:     self.execute_model(model_input, kv_caches, intermediate_tensors)
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner_base.py", line 152, in _wrapper
[rank0]:     raise type(err)(
[rank0]: ValueError: Error in model execution (input dumped to /tmp/err_execute_model_input_20250207-175904.pkl): Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
```

There's an identical message w/ `[rank1]` since I'm using 2 GPUs for training.
Here's `trl env`:
- Platform: Linux-5.15.0-131-generic-x86_64-with-glibc2.35
- Python version: 3.12.8
- PyTorch version: 2.5.1
- CUDA device(s): NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000
- Transformers version: 4.48.2
- Accelerate version: 1.3.0
- Accelerate config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: DEEPSPEED
  - mixed_precision: bf16
  - use_cpu: False
  - debug: False
  - num_processes: 2
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - enable_cpu_affinity: False
  - deepspeed_config: {'gradient_accumulation_steps': 4, 'gradient_clipping': 1.0, 'offload_optimizer_device': 'cpu', 'offload_param_device': 'cpu', 'zero3_init_flag': False, 'zero_stage': 2}
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
  - dynamo_config: {'dynamo_backend': 'INDUCTOR'}
- Datasets version: 3.2.0
- HF Hub version: 0.28.1
- TRL version: 0.15.0.dev0+848352d
- bitsandbytes version: 0.45.1
- DeepSpeed version: 0.16.3
- Diffusers version: 0.32.2
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: 1.61.0
- PEFT version: 0.14.0

In my main script, the error occurs when initializing the trainer:

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_funcs,
    args=training_config,
    train_dataset=dataset,
    eval_dataset=dataset,
    peft_config=peft_config,
)

Here's my peft config. Everything's in a YAML file and parsed via HfArgumentParser, and I've verified that the resulting LoRA config looks exactly as expected (a LoraConfig sketch of the same settings follows the YAML below).

use_peft: True
lora_r: 16
lora_alpha: 64
lora_dropout: 0.05
lora_target_modules: 
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "gate_proj"
    - "up_proj"
    - "down_proj"

I can share the rest if it's important to know training_config and model_config, but if I remove the above lines from my config (and don't use peft at all) initialization proceeds smoothly. If you have a set of training args/model args that works, I'm happy to test; if it still fails on my end, that probably means my customizations to grpo_trainer.py are clashing (though I didn't touch __init__).

@tchang1997

As an update, I traced the underlying issue to vLLM — the error comes from vLLM's initial profiling run, specifically the dummy LoRA requests, where some tensors are initialized on CPU.

Interestingly, the issue I described goes away when I set the LLM device to cuda:0 instead of cuda:1, and inspecting llm.llm_engine.model_executor.driver_worker.model_runner.model shows that vLLM is indeed running a LoRA-ified model. I still don't know the root cause of my issue, but it's unlikely to come from these changes.

@ingambe commented Feb 8, 2025

Same issue on my side on 2xA100. Placing the vLLM model on cuda:0 fixes it, but then all models end up on the same device.
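To make that workaround concrete, a hedged sketch assuming this TRL version's GRPOConfig exposes the vLLM options shown below (option names may differ across versions); pinning vLLM to cuda:0 avoids the Triton cpu-tensor error above at the cost of sharing the device with training.

```python
from trl import GRPOConfig

training_config = GRPOConfig(
    output_dir="outputs",             # placeholder
    use_vllm=True,
    vllm_device="cuda:0",             # workaround: avoid the error seen when vLLM sits on cuda:1
    vllm_gpu_memory_utilization=0.3,  # leave headroom, since training now shares this GPU
)
```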

@qgallouedec qgallouedec mentioned this pull request Feb 10, 2025
@BenjaminBossan (Member) left a comment


I'm not sure whether this PR is superseded by #2818 or not, but just in case it's not, I've added my thoughts on it.

if peft_config is not None:
    model = get_peft_model(model, peft_config)
    if isinstance(peft_config, LoraConfig):
        lora_rank = peft_config.r

Just note that there can be different ranks for different layers: https://huggingface.co/docs/peft/v0.14.0/en/package_reference/lora#peft.LoraConfig.rank_pattern. Not sure whether vLLM supports that or not. This might be relevant for determining the max_lora_rank below.
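To illustrate, a small sketch (mine, not from the PR) of how a max_lora_rank computation could account for rank_pattern, which maps layer-name patterns to per-layer ranks:

```python
from peft import LoraConfig

def compute_max_lora_rank(peft_config: LoraConfig) -> int:
    # start from the default rank, then consider any per-layer overrides
    ranks = [peft_config.r]
    if getattr(peft_config, "rank_pattern", None):
        ranks.extend(peft_config.rank_pattern.values())
    return max(ranks)
```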

from vllm.lora.request import LoRARequest

# Enable runtime LoRA updating on every new checkpoint
os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"

I'm not a fan of permanently setting env vars, as this can lead to hard-to-debug errors. Ideally, something like patch_environment from accelerate could be used instead. Otherwise, a destructor (__del__) could be defined on the appropriate object, but that sounds brittle.
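For illustration, a sketch of the patch_environment alternative; the variable is only set inside the block and restored on exit, so it needs to wrap whatever code actually relies on it (the model name and rank below are placeholders).

```python
from accelerate.utils import patch_environment
from vllm import LLM

model_name = "Qwen/Qwen2.5-1.5B-Instruct"  # placeholder

# The env var exists only for the duration of this block.
with patch_environment(VLLM_ALLOW_RUNTIME_LORA_UPDATING="True"):
    llm = LLM(model=model_name, enable_lora=True, max_lora_rank=16)
```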

with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
    if isinstance(unwrapped_model, PeftModel):
        unwrapped_model = copy.deepcopy(unwrapped_model)
        unwrapped_model.merge_and_unload()

Similar comment as for the other PR: Consider using model.merge_adapter, cleaning up the state dict, then model.unmerge_adapter, instead of creating a whole copy of the model.
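Roughly what that could look like, as a sketch under the assumption that unwrapped_model is a PeftModel; the exact state-dict key cleanup may need adjusting for the model in question.

```python
unwrapped_model.merge_adapter()
state_dict = {
    # strip the PEFT wrapper prefix and the injected base_layer indirection,
    # and drop adapter-only weights, so the keys match the base model's names
    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
    for k, v in unwrapped_model.state_dict().items()
    if "lora_" not in k
}
unwrapped_model.unmerge_adapter()
```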

        unwrapped_model.merge_and_unload()
    state_dict = unwrapped_model.state_dict()

    if self.accelerator.is_main_process:

AFAICT, the state_dict is only needed on the main process. So how about moving the whole code block above into this if block?
