@@ -36,6 +36,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import supports_custom_op


 @dataclass
@@ -95,32 +96,33 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)  # type: ignore


-@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
-def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
-    assert group_name in _groups, f"Group {group_name} is not found."
-    group = _groups[group_name]()
-    if group is None:
-        raise ValueError(f"Group {group_name} is destroyed.")
-    group._all_reduce(tensor)
+if supports_custom_op():

+    @torch.library.custom_op("vllm::inplace_all_reduce",
+                             mutates_args=["tensor"])
+    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        group._all_reduce(tensor)

-@inplace_all_reduce.register_fake
-def _(tensor: torch.Tensor, group_name: str) -> None:
-    return
-
-
-@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
-def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
-    assert group_name in _groups, f"Group {group_name} is not found."
-    group = _groups[group_name]()
-    if group is None:
-        raise ValueError(f"Group {group_name} is destroyed.")
-    return group._all_reduce(tensor)
+    @inplace_all_reduce.register_fake
+    def _(tensor: torch.Tensor, group_name: str) -> None:
+        return

+    @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
+    def outplace_all_reduce(tensor: torch.Tensor,
+                            group_name: str) -> torch.Tensor:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        return group._all_reduce(tensor)

-@outplace_all_reduce.register_fake
-def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
-    return torch.empty_like(tensor)
+    @outplace_all_reduce.register_fake
+    def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+        return torch.empty_like(tensor)


 class GroupCoordinator:
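The guard matters on PyTorch builds that predate the torch.library.custom_op API, where the decorators above would fail at import time. The register_fake registrations remain inside the same guard; they give torch.compile a shape-only implementation so tracing never executes the real collective. As a rough sketch of what the supports_custom_op helper imported from vllm.utils could check (an assumption; the actual body lives in vllm/utils.py), a simple feature probe is enough:

# Assumed sketch of the helper gating the registrations above; the real
# definition is in vllm/utils.py. torch.library.custom_op only exists in
# newer PyTorch releases, so probing for the attribute is sufficient.
import torch

def supports_custom_op() -> bool:
    return hasattr(torch.library, "custom_op")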
@@ -335,6 +337,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         if self.world_size == 1:
             return input_

+        if not supports_custom_op():
+            return self._all_reduce(input_)
+
         if self.tpu_communicator is not None and \
             not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
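With the early return in place, builds without custom-op support call the communicator helper directly and never touch the torch.ops.vllm.* wrappers, which no longer exist there. On supporting builds, the method continues past this point and dispatches through the registered ops, so the collective shows up as a single opaque node under Dynamo. A hedged sketch of that dispatch tail, not the verbatim method body (use_outplace is a hypothetical stand-in for the real communicator checks):

# Sketch only: the op names match the registrations above, but
# use_outplace stands in for whatever condition selects the
# out-of-place path in the real method.
if use_outplace:
    # Out-of-place variant returns a fresh tensor (see its fake impl).
    return torch.ops.vllm.outplace_all_reduce(input_,
                                              group_name=self.unique_name)
torch.ops.vllm.inplace_all_reduce(input_, group_name=self.unique_name)
return input_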