
Commit c42590f

[Hardware] [Intel GPU] refactor xpu worker/executor (vllm-project#7686)
1 parent aae6927 commit c42590f

File tree

3 files changed: +26 -28 lines changed


vllm/executor/xpu_executor.py

+15 -23
@@ -1,16 +1,16 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
 
 import torch
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, PromptAdapterConfig,
-                         SchedulerConfig, SpeculativeConfig)
+                         ModelConfig, ObservabilityConfig, ParallelConfig,
+                         PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
 from vllm.utils import make_async
-from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
 
@@ -30,6 +30,7 @@ def __init__(
         lora_config: Optional[LoRAConfig],
         prompt_adapter_config: Optional[PromptAdapterConfig],
         speculative_config: Optional[SpeculativeConfig],
+        observability_config: Optional[ObservabilityConfig],
     ) -> None:
         assert device_config.device_type == "xpu"
         assert (not speculative_config
@@ -46,32 +47,23 @@ def __init__(
         self.device_config = device_config
         self.prompt_adapter_config = prompt_adapter_config
         self.speculative_config = None
+        self.observability_config = observability_config
 
         # Instantiate the worker and load the model to GPU.
         self._init_executor()
 
-    def _create_worker(self,
-                       local_rank: int = 0,
-                       rank: int = 0,
-                       distributed_init_method: Optional[str] = None):
-        if self.speculative_config is None:
-            worker_module_name = "vllm.worker.xpu_worker"
-            worker_class_name = "XPUWorker"
-        else:
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.speculative_config is not None:
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
-
-        wrapper = WorkerWrapperBase(
-            worker_module_name=worker_module_name,
-            worker_class_name=worker_class_name,
-        )
-        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
-                                                      distributed_init_method))
-        return wrapper.worker
+        else:
+            worker_module_name = "vllm.worker.xpu_worker"
+            worker_class_name = "XPUWorker"
+        return (worker_module_name, worker_class_name)
 
     def execute_model(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
         output = self.driver_worker.execute_model(execute_model_req)
         return output
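Taken together, these changes drop the XPU-specific _create_worker override: XPUExecutor now only reports which worker module and class to load via _get_worker_module_and_class, and inherits the generic, wrapper-based creation path from GPUExecutor. Below is a minimal sketch of that hook pattern; the class names with the Sketch suffix are illustrative, and the real GPUExecutor path (WorkerWrapperBase, per-rank kwargs, speculative-decoding handling) is more involved.

import importlib
from typing import Any, Dict, Tuple


class BaseExecutorSketch:

    def _get_worker_module_and_class(self) -> Tuple[str, str]:
        # Default worker; device-specific subclasses override only this hook.
        return ("vllm.worker.worker", "Worker")

    def _get_worker_kwargs(self) -> Dict[str, Any]:
        # Stand-in for the per-rank config bundle the real executor builds.
        return {}

    def _create_worker(self) -> Any:
        # Generic creation path: resolve the reported module/class pair and
        # instantiate it, so subclasses need no worker-construction code.
        module_name, class_name = self._get_worker_module_and_class()
        worker_cls = getattr(importlib.import_module(module_name), class_name)
        return worker_cls(**self._get_worker_kwargs())


class XPUExecutorSketch(BaseExecutorSketch):

    def _get_worker_module_and_class(self) -> Tuple[str, str]:
        # Mirrors the new XPUExecutor hook in the diff above.
        return ("vllm.worker.xpu_worker", "XPUWorker")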

vllm/worker/xpu_model_runner.py

-1
@@ -137,7 +137,6 @@ def load_model(self) -> None:
                 device_config=self.device_config,
                 load_config=self.load_config,
                 lora_config=self.lora_config,
-                multimodal_config=self.multimodal_config,
                 parallel_config=self.parallel_config,
                 scheduler_config=self.scheduler_config,
                 cache_config=self.cache_config,

vllm/worker/xpu_worker.py

+11 -4
@@ -9,8 +9,8 @@
 import torch.distributed
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
+                         ModelConfig, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -50,6 +50,7 @@ def __init__(
         speculative_config: Optional[SpeculativeConfig] = None,
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
+        observability_config: Optional[ObservabilityConfig] = None,
     ) -> None:
         assert device_config.device_type == "xpu"
         assert is_xpu()
@@ -67,8 +68,10 @@ def __init__(
         self.lora_config = lora_config
         self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
+        self.observability_config = observability_config
+        if parallel_config and is_driver_worker:
+            assert rank % parallel_config.tensor_parallel_size == 0, \
+                "Driver worker should be rank 0 of tensor parallel group."
 
         self.multimodal_config = multimodal_config
 
@@ -183,7 +186,11 @@ def init_worker_distributed_environment(self) -> None:
         # dependency (libdrm and drm headers) on your system.
         ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                             "sockets")
+        ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
+                                         str(parallel_config.world_size))
         os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE
+        os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+        os.environ["LOCAL_RANK"] = str(self.local_rank)
         init_distributed_environment(
             world_size=parallel_config.world_size,
             rank=rank,
