@@ -1,16 +1,16 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
 
 import torch
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, PromptAdapterConfig,
-                         SchedulerConfig, SpeculativeConfig)
+                         ModelConfig, ObservabilityConfig, ParallelConfig,
+                         PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
 from vllm.utils import make_async
-from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
 
@@ -30,6 +30,7 @@ def __init__(
         lora_config: Optional[LoRAConfig],
         prompt_adapter_config: Optional[PromptAdapterConfig],
         speculative_config: Optional[SpeculativeConfig],
+        observability_config: Optional[ObservabilityConfig],
     ) -> None:
         assert device_config.device_type == "xpu"
         assert (not speculative_config
@@ -46,32 +47,23 @@ def __init__(
         self.device_config = device_config
         self.prompt_adapter_config = prompt_adapter_config
         self.speculative_config = None
+        self.observability_config = observability_config
 
         # Instantiate the worker and load the model to GPU.
         self._init_executor()
 
-    def _create_worker(self,
-                       local_rank: int = 0,
-                       rank: int = 0,
-                       distributed_init_method: Optional[str] = None):
-        if self.speculative_config is None:
-            worker_module_name = "vllm.worker.xpu_worker"
-            worker_class_name = "XPUWorker"
-        else:
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.speculative_config is not None:
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
-
-        wrapper = WorkerWrapperBase(
-            worker_module_name=worker_module_name,
-            worker_class_name=worker_class_name,
-        )
-        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
-                                                      distributed_init_method))
-        return wrapper.worker
+        else:
+            worker_module_name = "vllm.worker.xpu_worker"
+            worker_class_name = "XPUWorker"
+        return (worker_module_name, worker_class_name)
 
     def execute_model(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
         output = self.driver_worker.execute_model(execute_model_req)
         return output
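
Context for the refactor (not part of the diff): the XPU-specific `_create_worker` override is removed, leaving `_get_worker_module_and_class` as the only per-device hook, so the construction logic deleted above presumably now lives in the shared `GPUExecutor` base. A minimal sketch of that consuming pattern, reassembled from the removed lines (`WorkerWrapperBase`, `_get_worker_kwargs`, and the hook name all appear in this diff; treating this as the actual base-class code would be an assumption):

    # Sketch only: assumed to live in the GPUExecutor base class,
    # not taken verbatim from this diff.
    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None):
        # Subclasses such as XPUExecutor override just this hook.
        worker_module_name, worker_class_name = (
            self._get_worker_module_and_class())
        # The wrapper imports the named module and instantiates the worker.
        wrapper = WorkerWrapperBase(worker_module_name=worker_module_name,
                                    worker_class_name=worker_class_name)
        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
                                                      distributed_init_method))
        return wrapper.worker

With that split, each executor only names its worker module and class, and the shared wrapper handles importing and instantiating it.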