Skip to content

Commit 12c6009

Browse files
committed
[TPU] add tpu_inference
Signed-off-by: Johnny Yang <[email protected]>
1 parent 1b622de commit 12c6009

File tree

4 files changed: 3 additions and 13 deletions.

requirements/tpu.txt

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,4 @@ ray[data]
1212
setuptools==78.1.0
1313
nixl==0.3.0
1414
tpu_info==0.4.0
15-
16-
# Install torch_xla
17-
torch_xla[tpu, pallas]==2.8.0
15+
tpu-inference

vllm/distributed/device_communicators/tpu_communicator.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -97,11 +97,3 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
9797
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
9898
assert dim == -1, "TPUs only support dim=-1 for all-gather."
9999
return xm.all_gather(input_, dim=dim)
100-
101-
102-
if USE_TPU_INFERENCE:
103-
from tpu_inference.distributed.device_communicators import (
104-
TpuCommunicator as TpuInferenceCommunicator,
105-
)
106-
107-
TpuCommunicator = TpuInferenceCommunicator # type: ignore

vllm/platforms/tpu.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -262,7 +262,7 @@ def check_max_model_len(cls, max_model_len: int) -> int:
262262

263263

264264
try:
265-
from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
265+
from tpu_inference.platforms.tpu_platforms import TpuPlatform as TpuInferencePlatform
266266

267267
TpuPlatform = TpuInferencePlatform # type: ignore
268268
USE_TPU_INFERENCE = True

vllm/v1/worker/tpu_worker.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -350,6 +350,6 @@ def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
350350

351351

352352
if USE_TPU_INFERENCE:
353-
from tpu_inference.worker import TPUWorker as TpuInferenceWorker
353+
from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
354354

355355
TPUWorker = TpuInferenceWorker # type: ignore

0 commit comments

Comments
 (0)