import torch
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

import pplx_kernels as pplx
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize


+# Note: use layer.get_all_to_all() to obtain an AllToAll instance.
+# The max_num_tokens, world_size and dp_size arguments must match the
+# values used to create the AllToAll; unfortunately, there is no way(?)
+# to extract this info from the AllToAll object itself.
class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine):
-    def __init__(self, a2a: pplx.AllToAll):
+    def __init__(
+            self,
+            a2a: pplx.AllToAll,
+            max_num_tokens: int,
+            world_size: int,
+            dp_size: int,
+            block_shape: Optional[List[int]] = None):
        super().__init__()
        self.a2a = a2a
+        self.block_shape = block_shape
+        self.dp_num_tokens = max_num_tokens * (world_size // dp_size)

    def dispatch(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
-        topk_ids: torch.Tensor,
+        rank_topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Is this always going to be a1.device?
+        device = a1.device
+
+        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+            a2_scale.numel() != 1 if a2_scale is not None else False)
+
+        a1q, a1q_scale = _fp8_quantize(
+            a1,
+            a1_scale,
+            self.block_shape,
+            per_act_token,
+        )
+
+        expert_num_tokens = torch.empty(
+            num_experts,
+            dtype=torch.int32,
+            device=device,
+        )
+
+        expert_x = torch.empty(
+            (num_experts, self.dp_num_tokens, a1q.shape[-1]),
+            dtype=a1q.dtype,
+            device=device,
+        )
+
+        expert_x_scale: Optional[torch.Tensor] = None
+        if a1q.dtype.itemsize == 1:
+            float32_size = torch.float32.itemsize
+            block_size = (self.block_shape[0]
+                          if self.block_shape is not None else 1) * float32_size
+            expert_x_scale = torch.empty(
+                (
+                    num_experts,
+                    expert_x.size(1),
+                    (expert_x.size(2) + block_size - 1) // block_size,
+                ),
+                dtype=torch.float32,
+                device=device,
+            )
+
+        # This argument is optional
+        bound_m = torch.tensor([a1q.shape[0]],
+                               dtype=torch.uint32,
+                               device=device)
        self.a2a.dispatch(
-            out_expert_num_tokens,  # torch.Tensor,
-            out_expert_x,  # torch.Tensor,
-            out_expert_x_scale,  # torch.Tensor | None,
-            dp_x,  # torch.Tensor,
-            dp_x_scale,  # torch.Tensor | None,
-            indices,  # torch.Tensor,
-            bound_m,  # torch.Tensor | None,
-            do_send,  # bool = True,
-            do_recv,  # bool = True,
+            out_expert_num_tokens=expert_num_tokens,
+            out_expert_x=expert_x,
+            out_expert_x_scale=expert_x_scale,
+            dp_x=a1q,
+            dp_x_scale=a1q_scale,
+            indices=rank_topk_ids,
+            bound_m=bound_m,
        )
-        return a1q, a1q_scale, topk_ids
+        return expert_x, expert_x_scale

    def combine(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
    ) -> None:
-        self.a2a.combine(
-            out_tokens,  #: torch.Tensor,
-            indices,  #: torch.Tensor,
-            weights,  #: torch.Tensor,
-            expert_y,  #: torch.Tensor,
-            bound_m,  #: torch.Tensor | None,
-            do_send,  #: bool = True,
-            do_recv,  #: bool = True,
-        )
+        # This argument is optional
+        bound_m = torch.tensor([output.shape[0]],
+                               dtype=torch.uint32,
+                               device=output.device)

+        # TODO: assert output is the proper size

-# singleton-ish
-def get_a2a(
-    max_num_tokens: int,
-    num_experts: int,
-    experts_per_token: int,
-    rank: int,
-    world_size: int,
-    dp_size: int,
-    hidden_dim: int,
-    hidden_dim_bytes: int,
-    hidden_dim_scale_bytes: int,
-) -> pplx.AllToAll:
-    pass
+        self.a2a.combine(
+            out_tokens=output,
+            indices=topk_ids,
+            weights=topk_weights,
+            expert_y=fused_expert_output,
+            bound_m=bound_m,
+        )
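
For orientation, a minimal usage sketch of the wrapper above, not part of the diff. All names other than PplxDispatchCombine (layer, max_num_tokens, world_size, dp_size, the activation and top-k tensors, num_experts, output) are hypothetical placeholders, and the AllToAll is assumed to have been created elsewhere with the same max_num_tokens / world_size / dp_size, as the note in the diff requires.

# Hypothetical wiring; assumes layer.get_all_to_all() returns the pplx.AllToAll
# that was built with the same sizing parameters passed below.
a2a = layer.get_all_to_all()
dispatch_combine = PplxDispatchCombine(
    a2a,
    max_num_tokens=max_num_tokens,
    world_size=world_size,
    dp_size=dp_size,
    block_shape=None,
)

# dispatch() fp8-quantizes the rank-local activations and scatters them into
# per-expert receive buffers; combine() reduces the weighted expert outputs
# back into `output`.
expert_x, expert_x_scale = dispatch_combine.dispatch(
    a1, a1_scale, a2_scale, rank_topk_ids, num_experts, expert_map=None)
# ... run the fused experts on (expert_x, expert_x_scale) ...
dispatch_combine.combine(output, fused_expert_output, topk_weights, topk_ids)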