
Commit d4bf085

[MISC] add support custom_op check (vllm-project#8557)
Co-authored-by: youkaichao <[email protected]>
1 parent: 0057894

File tree: 2 files changed, +33 -22 lines

  vllm/distributed/parallel_state.py
  vllm/utils.py

vllm/distributed/parallel_state.py (27 additions, 22 deletions)
```diff
@@ -36,6 +36,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import supports_custom_op
 
 
 @dataclass
@@ -95,32 +96,33 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)  # type: ignore
 
 
-@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
-def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
-    assert group_name in _groups, f"Group {group_name} is not found."
-    group = _groups[group_name]()
-    if group is None:
-        raise ValueError(f"Group {group_name} is destroyed.")
-    group._all_reduce(tensor)
+if supports_custom_op():
 
+    @torch.library.custom_op("vllm::inplace_all_reduce",
+                             mutates_args=["tensor"])
+    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        group._all_reduce(tensor)
 
-@inplace_all_reduce.register_fake
-def _(tensor: torch.Tensor, group_name: str) -> None:
-    return
-
-
-@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
-def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
-    assert group_name in _groups, f"Group {group_name} is not found."
-    group = _groups[group_name]()
-    if group is None:
-        raise ValueError(f"Group {group_name} is destroyed.")
-    return group._all_reduce(tensor)
+    @inplace_all_reduce.register_fake
+    def _(tensor: torch.Tensor, group_name: str) -> None:
+        return
 
+    @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
+    def outplace_all_reduce(tensor: torch.Tensor,
+                            group_name: str) -> torch.Tensor:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        return group._all_reduce(tensor)
 
-@outplace_all_reduce.register_fake
-def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
-    return torch.empty_like(tensor)
+    @outplace_all_reduce.register_fake
+    def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+        return torch.empty_like(tensor)
 
 
 class GroupCoordinator:
```
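For context on the API this change guards: `torch.library.custom_op` (added in PyTorch 2.4) turns a Python function into a dispatcher-registered op, and `register_fake` attaches a shape-and-dtype-only implementation that `torch.compile` uses during tracing, which is what the diff's `torch.empty_like` fake does for `outplace_all_reduce`. A minimal sketch of the same pattern with a hypothetical `mylib::scale` op, guarded by the same `hasattr` probe that `supports_custom_op()` performs:

```python
import torch

if hasattr(torch.library, "custom_op"):  # what supports_custom_op() checks

    @torch.library.custom_op("mylib::scale", mutates_args=())
    def scale(tensor: torch.Tensor, factor: float) -> torch.Tensor:
        # Real implementation, executed at runtime in eager mode.
        return tensor * factor

    @scale.register_fake
    def _(tensor: torch.Tensor, factor: float) -> torch.Tensor:
        # Fake (meta) implementation: only shape/dtype, no computation.
        # torch.compile uses this while tracing through the op.
        return torch.empty_like(tensor)

    out = torch.ops.mylib.scale(torch.ones(4), 2.0)
    print(out)  # tensor([2., 2., 2., 2.])
```

Keeping the registrations inside the guard, as the diff does with its `if supports_custom_op():` block, is what lets the module still import on PyTorch builds older than 2.4.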
```diff
@@ -335,6 +337,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         if self.world_size == 1:
             return input_
 
+        if not supports_custom_op():
+            return self._all_reduce(input_)
+
         if self.tpu_communicator is not None and \
                 not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
```
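This early return gives builds without custom-op support a plain eager path through `GroupCoordinator._all_reduce`. A hedged sketch of the resulting dispatch: the names `world_size`, `_all_reduce`, and `unique_name` come from the diff context, while the final `torch.ops.vllm.outplace_all_reduce` call is an assumption about how the method continues past the end of the hunk:

```python
import torch

def dispatch_all_reduce(coordinator, input_: torch.Tensor) -> torch.Tensor:
    # Hypothetical free-function rendering of the guarded code path.
    if coordinator.world_size == 1:
        return input_  # nothing to reduce in a single-rank group
    if not hasattr(torch.library, "custom_op"):  # supports_custom_op()
        # Older PyTorch: call the eager implementation directly.
        return coordinator._all_reduce(input_)
    # Newer PyTorch: route through the registered op so torch.compile
    # sees one opaque, traceable node (assumed continuation of the hunk).
    return torch.ops.vllm.outplace_all_reduce(
        input_, group_name=coordinator.unique_name)
```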

vllm/utils.py (6 additions, 0 deletions)
```diff
@@ -1245,6 +1245,12 @@ def supports_dynamo() -> bool:
     return base_torch_version >= Version("2.4.0")
 
 
+# Some backends use pytorch version < 2.4.0 which doesn't
+# support `torch.library.custom_op`.
+def supports_custom_op() -> bool:
+    return hasattr(torch.library, "custom_op")
+
+
 class AtomicCounter:
     """An atomic, thread-safe counter"""
```
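The new helper deliberately probes for the attribute rather than parsing a version string. A small sketch contrasting the two gating styles in `vllm/utils.py`; the body of `supports_dynamo` is reconstructed from the context line above, so the `base_torch_version` assignment is an assumption:

```python
import torch
from packaging.version import Version

def supports_dynamo() -> bool:
    # Version-string gate (assumed assignment; the diff only shows
    # the return statement).
    base_torch_version = Version(Version(torch.__version__).base_version)
    return base_torch_version >= Version("2.4.0")

def supports_custom_op() -> bool:
    # Attribute probe: also correct on backends whose builds report an
    # old version string yet genuinely lack torch.library.custom_op.
    return hasattr(torch.library, "custom_op")

print(supports_dynamo(), supports_custom_op())
```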