Commit 89ad99a

[TPU] add tpu_inference
Signed-off-by: Johnny Yang <[email protected]>
1 parent ddeec11 commit 89ad99a

3 files changed: +3 -12 lines changed


requirements/tpu.txt

Lines changed: 2 additions & 3 deletions
@@ -12,6 +12,5 @@ ray[data]
 setuptools==78.1.0
 nixl==0.3.0
 tpu_info==0.4.0
-
-# Install torch_xla
-torch_xla[tpu, pallas]==2.8.0
+tpu-inference==0.11.1
+numba
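
Reviewer note: the hunk above replaces the `torch_xla` pin with `tpu-inference==0.11.1`. A minimal sketch of how an environment could be checked against that pin, assuming the distribution is installed under the name `tpu-inference`. The helper `check_pin` is hypothetical and not part of this commit or of vLLM.

```python
from importlib import metadata


def check_pin(dist_name: str, expected: str) -> str:
    """Return 'ok', 'mismatch', or 'missing' for an installed distribution.

    Hypothetical helper for illustration only.
    """
    try:
        installed = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        # The distribution is not installed in this environment.
        return "missing"
    return "ok" if installed == expected else "mismatch"


if __name__ == "__main__":
    # Pin taken from requirements/tpu.txt in this commit.
    print(check_pin("tpu-inference", "0.11.1"))
```

The result depends on the local environment, so no expected output is shown.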

vllm/distributed/device_communicators/tpu_communicator.py

Lines changed: 0 additions & 8 deletions
@@ -97,11 +97,3 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
         return xm.all_gather(input_, dim=dim)
-
-
-if USE_TPU_INFERENCE:
-    from tpu_inference.distributed.device_communicators import (
-        TpuCommunicator as TpuInferenceCommunicator,
-    )
-
-    TpuCommunicator = TpuInferenceCommunicator  # type: ignore

vllm/v1/worker/tpu_worker.py

Lines changed: 1 addition & 1 deletion
@@ -351,6 +351,6 @@ def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:


 if USE_TPU_INFERENCE:
-    from tpu_inference.worker import TPUWorker as TpuInferenceWorker
+    from tpu_inference.worker.tpu_worker_jax import TPUWorker as TpuInferenceWorker

     TPUWorker = TpuInferenceWorker  # type: ignore
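
Reviewer note: the hunk above rebinds the module-level name `TPUWorker` to the class imported from `tpu_inference.worker.tpu_worker_jax` when `USE_TPU_INFERENCE` is set. The sketch below illustrates the same name-rebinding idea with a different mechanism, an import-or-fallback lookup, so it runs without `tpu_inference` installed. It is not what vLLM does; `DefaultTPUWorker` and `resolve_worker` are hypothetical names.

```python
import importlib


class DefaultTPUWorker:
    """Stand-in for the worker class defined earlier in the module."""


def resolve_worker(module_path: str, default: type) -> type:
    """Return `TPUWorker` from `module_path` if importable, else `default`.

    Hypothetical helper: the commit uses a plain `if USE_TPU_INFERENCE:`
    guard instead of catching ImportError.
    """
    try:
        module = importlib.import_module(module_path)
    except ImportError:
        # Package (or submodule) is absent: keep the default class.
        return default
    return getattr(module, "TPUWorker", default)


if __name__ == "__main__":
    # Module path taken from the added line in this commit.
    TPUWorker = resolve_worker(
        "tpu_inference.worker.tpu_worker_jax", DefaultTPUWorker
    )
    print(TPUWorker.__name__)
```

One trade-off of the fallback variant is that a broken install degrades silently to the default class, whereas the flag-guarded import in the diff fails loudly if the flag is set and the module path is wrong.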
