
[Bug]: Speculative decoding reports errors when loading the target model with distributed inference (vLLM's official Ray setup) #12841

@Neo9061

Description

Your current environment

  1. vLLM: latest official OpenAI-compatible container image.
  2. The Ray cluster has two nodes of 8 x H100. I set up the cluster, confirmed that `ray status` looks healthy, and ran the Python script below inside the container (see the sanity-check sketch after this list).
  3. I am doing offline distributed inference with Ray, following the official instructions.
  4. Without speculative decoding, I can successfully load the model for distributed inference via the LLM class.
  5. When I pass the speculative decoding arguments to the LLM class, it reports an error.
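
For reference, a health check along these lines (roughly what `ray status` already told me, expressed with the standard `ray.init()` / `ray.cluster_resources()` APIs; the 16-GPU threshold is specific to my 2 x 8 H100 cluster) confirms that both nodes and all GPUs are visible before constructing the LLM:

import ray

# Sketch of a cluster sanity check, run inside the container on the head node.
# Assumes the Ray cluster is already started; the GPU threshold below is
# specific to this 2 x 8 H100 setup.
ray.init(address="auto")
resources = ray.cluster_resources()
alive_nodes = sum(1 for node in ray.nodes() if node["Alive"])
print(f"alive nodes: {alive_nodes}, GPUs visible to Ray: {resources.get('GPU', 0)}")
assert resources.get("GPU", 0) >= 16, "Ray does not see all 16 GPUs"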

🐛 Describe the bug

The reproducible code is below. Note: if you remove the speculative decoding arguments, the model loads successfully.

import time

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0, max_tokens=512)

llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    tensor_parallel_size=16,
    speculative_model="meta-llama/Llama-3.2-1B-Instruct",
    speculative_draft_tensor_parallel_size=1,
    num_speculative_tokens=5,
    disable_log_stats=False,
    enforce_eager=True,
    trust_remote_code=True
)


time.sleep(5)

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


The error message is as follows.

INFO 02-06 08:03:56 model_runner.py:1111] Starting to load model /root/models/llama-3-3-70b/Llama-3.1-70B-Instruct...
Loading safetensors checkpoint shards:   0% Completed | 0/30 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 30/30 [04:14<00:00,  8.48s/it]

INFO 02-06 08:08:11 model_runner.py:1116] Loading model weights took 8.4050 GB
INFO 02-06 08:08:11 model_runner.py:1111] Starting to load model /root/models/eagle-head/Llama-3.2-1B-Instruct...
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:17<00:00, 17.94s/it]

(RayWorkerWrapper pid=14182, ip=172.31.18.145) INFO 02-06 08:08:12 model_runner.py:1116] Loading model weights took 8.4050 GB
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572] Error executing method 'init_device'. This might cause deadlock in distributed execution.
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572] Traceback (most recent call last):
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 564, in execute_method
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]     return run_method(target, method, args, kwargs)
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2208, in run_method
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]     return func(*args, **kwargs)
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]            ^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]   File "/usr/local/lib/python3.12/dist-packages/vllm/spec_decode/spec_decode_worker.py", line 329, in init_device
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]     self.spec_decode_sampler.init_tensors(self.rank,
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/spec_decode_base_sampler.py", line 54, in init_tensors
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]     self.num_accepted_tokens = torch.tensor(0,
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]                                ^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572] RuntimeError: CUDA error: invalid device ordinal
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
(RayWorkerWrapper pid=14182, ip=172.31.18.145) ERROR 02-06 08:08:12 worker_base.py:572]
(RayWorkerWrapper pid=14179, ip=172.31.18.145) WARNING 02-06 08:03:55 custom_all_reduce.py:82] Custom allreduce is disabled because this process group spans across nodes. [repeated 14x across cluster]
(RayWorkerWrapper pid=14179, ip=172.31.18.145) INFO 02-06 08:03:55 model_runner.py:1111] Starting to load model /root/models/llama-3-3-70b/Llama-3.1-70B-Instruct... [repeated 14x across cluster]
(RayWorkerWrapper pid=2384) INFO 02-06 08:08:12 spec_decode_worker.py:339] [Speculative Decoding] Use MQA scorer for scoring proposals.
INFO 02-06 08:08:29 model_runner.py:1116] Loading model weights took 2.3185 GB
INFO 02-06 08:08:29 spec_decode_worker.py:339] [Speculative Decoding] Use MQA scorer for scoring proposals.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/vllm-workspace/run2.py", line 10, in <module>
[rank0]:     llm = LLM(
[rank0]:           ^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 1039, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/llm.py", line 240, in __init__
[rank0]:     self.llm_engine = self.engine_class.from_engine_args(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 482, in from_engine_args
[rank0]:     engine = cls(
[rank0]:              ^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 271, in __init__
[rank0]:     self.model_executor = executor_class(vllm_config=vllm_config, )
[rank0]:                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 260, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 49, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/ray_distributed_executor.py", line 88, in _init_executor
[rank0]:     self._init_workers_ray(placement_group)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/ray_distributed_executor.py", line 343, in _init_workers_ray
[rank0]:     self._run_workers("init_device")
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/ray_distributed_executor.py", line 469, in _run_workers
[rank0]:     ray_worker_outputs = ray.get(ray_worker_outputs)
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py", line 2772, in get
[rank0]:     values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
[rank0]:                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py", line 919, in get_objects
[rank0]:     raise value.as_instanceof_cause()
[rank0]: ray.exceptions.RayTaskError(RuntimeError): ray::RayWorkerWrapper.execute_method() (pid=14177, ip=172.31.18.145, actor_id=ac50405ac53b631dcd36345f12000000, repr=<vllm.executor.ray_utils.RayWorkerWrapper object at 0x70f0dc40c200>)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 573, in execute_method
[rank0]:     raise e
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 564, in execute_method
[rank0]:     return run_method(target, method, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2208, in run_method
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/spec_decode/spec_decode_worker.py", line 329, in init_device
[rank0]:     self.spec_decode_sampler.init_tensors(self.rank,
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/spec_decode_base_sampler.py", line 54, in init_tensors
[rank0]:     self.num_accepted_tokens = torch.tensor(0,
[rank0]:                                ^^^^^^^^^^^^^^^
[rank0]: RuntimeError: CUDA error: invalid device ordinal
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.                                                      
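
For what it's worth, my reading of the trace above (not confirmed as the root cause): init_tensors appears to place its counter tensors on a CUDA device indexed by the worker's rank, and the failing worker lives on the second node, where the global rank exceeds the 8 GPUs visible locally. The snippet below is only a standalone sketch of that error class, not of vLLM's actual code path:

import torch

# Standalone sketch of the error class only (assumes a CUDA build of PyTorch).
# Asking for a device index >= torch.cuda.device_count() raises
# "CUDA error: invalid device ordinal", the same error the worker reports.
local_gpus = torch.cuda.device_count()   # 8 on a single 8 x H100 node
hypothetical_rank = 12                   # illustrative rank from the second node

try:
    torch.tensor(0, dtype=torch.long, device=f"cuda:{hypothetical_rank}")
except RuntimeError as err:
    print(f"device_count={local_gpus}, cuda:{hypothetical_rank} -> {err}")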

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
