-
-
Notifications
You must be signed in to change notification settings - Fork 11.3k
[TPU] add tpu_inference #27277
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[TPU] add tpu_inference #27277
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request transitions the TPU backend to use the tpu-inference library, updating dependencies and import paths accordingly. The changes correctly adapt vLLM to the new library version. However, I've identified a critical issue with the dependency management in requirements/tpu.txt. By removing torch_xla, the fallback mechanism for when tpu-inference is unavailable becomes non-functional, which could lead to confusing runtime errors. My review includes a suggestion to address this to ensure the system remains robust.
requirements/tpu.txt
Outdated
| tpu-inference==0.11.1 | ||
| numba |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removing torch_xla from this requirements file makes the fallback path in the code fragile. The codebase contains logic to fall back to a torch_xla-based implementation if tpu-inference is not found or fails to import (e.g., in vllm/platforms/tpu.py).
With torch_xla removed from the dependencies, if tpu-inference is installed but fails to import for any reason (like a transitive dependency issue), the fallback will immediately fail with an ImportError for torch_xla. This creates a brittle setup and can cause confusing errors for users.
To ensure the fallback mechanism is robust, please re-add torch_xla to the dependencies. This ensures that even if the primary tpu-inference path fails, the system can gracefully fall back to the torch_xla-based implementation.
tpu-inference==0.11.1
numba
# Install torch_xla for fallback
torch_xla[tpu, pallas]==2.8.0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
vllm/vllm/distributed/device_communicators/tpu_communicator.py
Lines 22 to 62 in 89ad99a
| if not USE_TPU_INFERENCE: | |
| logger.info("tpu_inference not found, using vLLM's TpuCommunicator") | |
| if current_platform.is_tpu(): | |
| import torch_xla | |
| import torch_xla.core.xla_model as xm | |
| import torch_xla.runtime as xr | |
| from torch_xla._internal import pjrt | |
| from torch_xla.distributed.xla_multiprocessing import ( | |
| create_optimized_replica_groups, | |
| ) | |
| if USE_RAY: | |
| from vllm.executor import ray_utils | |
| class TpuCommunicator(DeviceCommunicatorBase): | |
| def __init__( | |
| self, | |
| cpu_group: ProcessGroup, | |
| device: torch.device | None = None, | |
| device_group: ProcessGroup | None = None, | |
| unique_name: str = "", | |
| ): | |
| super().__init__(cpu_group, device, device_group, unique_name) | |
| # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node | |
| # must be used together. Therefore, the local rank and world size can | |
| # be simply calculated as follows. | |
| global_rank = self.global_rank | |
| global_world_size = self.global_world_size | |
| if USE_RAY: | |
| logger.info("TpuCommunicator initialized with RAY") | |
| # Calculate how many TPU nodes are in the current deployment. This | |
| # is the Ray placement group if it is deployed with Ray. Default | |
| # to the number of TPU nodes in the Ray cluster. The number of TPU | |
| # nodes is computed by the total number of TPUs divided by the | |
| # number of TPU accelerators per node, to account for clusters | |
| # with both CPUs and TPUs. | |
| num_nodes = ray_utils.get_num_tpu_nodes() | |
| num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() |
With tpu_inference installed USE_TPU_INFERENCE becomes True, but this module now always exposes the local TpuCommunicator instead of aliasing to the implementation in tpu_inference. The conditional import of ray_utils is still wrapped inside if not USE_TPU_INFERENCE, yet __init__ uses ray_utils whenever the distributed executor backend is Ray. In a TPU inference deployment that uses the Ray backend, ray_utils will be undefined and TpuCommunicator.__init__ will raise NameError before any communication is set up. The Ray utilities need to be imported regardless of USE_TPU_INFERENCE (or the Ray-specific logic needs a similar guard).
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
vllm/vllm/distributed/device_communicators/tpu_communicator.py
Lines 22 to 62 in 89ad99a
| if not USE_TPU_INFERENCE: | |
| logger.info("tpu_inference not found, using vLLM's TpuCommunicator") | |
| if current_platform.is_tpu(): | |
| import torch_xla | |
| import torch_xla.core.xla_model as xm | |
| import torch_xla.runtime as xr | |
| from torch_xla._internal import pjrt | |
| from torch_xla.distributed.xla_multiprocessing import ( | |
| create_optimized_replica_groups, | |
| ) | |
| if USE_RAY: | |
| from vllm.executor import ray_utils | |
| class TpuCommunicator(DeviceCommunicatorBase): | |
| def __init__( | |
| self, | |
| cpu_group: ProcessGroup, | |
| device: torch.device | None = None, | |
| device_group: ProcessGroup | None = None, | |
| unique_name: str = "", | |
| ): | |
| super().__init__(cpu_group, device, device_group, unique_name) | |
| # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node | |
| # must be used together. Therefore, the local rank and world size can | |
| # be simply calculated as follows. | |
| global_rank = self.global_rank | |
| global_world_size = self.global_world_size | |
| if USE_RAY: | |
| logger.info("TpuCommunicator initialized with RAY") | |
| # Calculate how many TPU nodes are in the current deployment. This | |
| # is the Ray placement group if it is deployed with Ray. Default | |
| # to the number of TPU nodes in the Ray cluster. The number of TPU | |
| # nodes is computed by the total number of TPUs divided by the | |
| # number of TPU accelerators per node, to account for clusters | |
| # with both CPUs and TPUs. | |
| num_nodes = ray_utils.get_num_tpu_nodes() | |
| num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() |
With tpu_inference installed USE_TPU_INFERENCE becomes True, but this module now always exposes the local TpuCommunicator instead of aliasing to the implementation in tpu_inference. The conditional import of ray_utils is still wrapped inside if not USE_TPU_INFERENCE, yet __init__ uses ray_utils whenever the distributed executor backend is Ray. In a TPU inference deployment that uses the Ray backend, ray_utils will be undefined and TpuCommunicator.__init__ will raise NameError before any communication is set up. The Ray utilities need to be imported regardless of USE_TPU_INFERENCE (or the Ray-specific logic needs a similar guard).
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
89ad99a to
12c6009
Compare
Signed-off-by: Johnny Yang <[email protected]>
12c6009 to
cb27073
Compare
Purpose
Add tpu-inference for vllm TPU.
Test Plan
E2E tests on v6e.
Test Result
Successful
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.