
Commit fc3243d

initial pplx dispatch/combine class
Signed-off-by: Bill Nell <[email protected]>
1 parent bfa3497

5 files changed: +97 -59 lines changed

tests/kernels/test_block_fp8.py  (+5 -5)

@@ -362,7 +362,7 @@ def fp8_perm(m, idx):
     return m[idx, ...]
 
 
-def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
+def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
     M, K = a.shape
 
     sorted_token_ids, m_indices, num_pad = moe_align_block_size(
@@ -381,7 +381,7 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
     return a, a_s, m_indices, inv_perm
 
 
-def test_moe_unpermute(out, inv_perm, topk, K, topk_weight):
+def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
     M = topk_weight.shape[0]
     out = out[inv_perm, ...]
     tmp_out = out.view(-1, topk, K)
@@ -403,8 +403,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
 
     a_q, a_s = per_token_group_quant_fp8(a, block_m)
 
-    a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids,
-                                                     num_groups, topk, block_m)
+    a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
+                                                 num_groups, topk, block_m)
 
     inter_out = torch.zeros((a_q.shape[0], N * 2),
                             dtype=torch.bfloat16,
@@ -421,7 +421,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
     deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
         (act_out_q, act_out_s), (w2, w2_s), out, m_indices)
 
-    final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight)
+    final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
 
     return final_out
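
The helpers above are renamed from test_moe_permute/test_moe_unpermute to _moe_permute/_moe_unpermute, presumably because pytest collects any module-level function whose name starts with test_, so a helper taking plain positional arguments would itself be collected and fail with a missing-fixture error. A minimal, hypothetical sketch of that behavior (not part of this commit):

# test_collection_sketch.py -- hypothetical example, not from this commit.
# Running `pytest test_collection_sketch.py` collects test_helper (and errors
# with "fixture 'a' not found") but never collects _helper.

def test_helper(a, b):      # name matches pytest's test_* pattern -> collected
    return a + b

def _helper(a, b):          # leading underscore -> not collected
    return a + b

def test_uses_helper():     # an actual test that calls the helper explicitly
    assert _helper(1, 2) == 3
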

vllm/model_executor/layers/fused_moe/dispatch_combine.py  (+3 -3)

@@ -21,7 +21,7 @@ def dispatch(
         topk_ids: torch.Tensor,
         num_experts: int,
         expert_map: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
             a2_scale.numel() != 1 if a2_scale is not None else False)
 
@@ -31,14 +31,14 @@ def dispatch(
             self.block_shape,
             per_act_token,
         )
-        return a1q, a1q_scale, topk_ids
+        return a1q, a1q_scale
 
     def combine(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
         topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
     ) -> None:
         _moe_unpermute_and_reduce(output, fused_expert_output, None,
                                   topk_weights)
-
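
The per_act_token flag above selects between per-token and per-tensor activation scales before calling _fp8_quantize (imported from vllm/model_executor/layers/fused_moe/utils.py in the new pplx module below; its implementation is not part of this commit). The following is only a rough sketch of the two scaling modes; the function name and the e4m3 format are assumptions, and block-wise quantization (block_shape) is ignored:

import torch

def fp8_quantize_sketch(a: torch.Tensor, per_act_token: bool):
    # Illustrative only: dynamic fp8 quantization with either one scale per
    # token (row) or a single scale for the whole tensor.
    finfo = torch.finfo(torch.float8_e4m3fn)
    if per_act_token:
        amax = a.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    else:
        amax = a.abs().amax().clamp(min=1e-12)
    scale = (amax / finfo.max).float()
    a_q = (a / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return a_q, scale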

vllm/model_executor/layers/fused_moe/fused_moe.py  (+4 -2)

@@ -1459,6 +1459,7 @@ def fused_moe(
         block_shape=block_shape)
 
 
+# TODO: merge with StandardDispatchCombine
 class TritonDispatchCombine(mk.FusedMoEQuantizeDispatchCombine):
 
     def __init__(self, use_fp8_w8a8: bool, block_shape: Optional[List[int]]):
@@ -1474,7 +1475,7 @@ def dispatch(
         topk_ids: torch.Tensor,
         num_experts: int,
         expert_map: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         if self.use_fp8_w8a8:
             a1q, a1q_scale = _fp8_quantize(
                 a1,
@@ -1485,13 +1486,14 @@ def dispatch(
             a1q = a1
             a1q_scale = a1_scale
 
-        return a1q, a1q_scale, topk_ids
+        return a1q, a1q_scale
 
     def combine(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
         topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
     ) -> None:
         M, topk = topk_weights.shape
         K = fused_expert_output.shape[-1]
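
TritonDispatchCombine.combine recovers M, topk and K (the hunk is cut off above) and then unpermutes and reduces the per-expert outputs into output. A rough sketch of that reduction, modeled on the _moe_unpermute helper in the test diff at the top of this page; the function name and the optional inv_perm argument are illustrative, not the _moe_unpermute_and_reduce implementation:

from typing import Optional

import torch

def unpermute_and_reduce_sketch(
    output: torch.Tensor,               # (M, K), written in place
    fused_expert_output: torch.Tensor,  # (M * topk, K), not reduced or weighted
    inv_perm: Optional[torch.Tensor],   # inverse of the dispatch permutation, if any
    topk_weights: torch.Tensor,         # (M, topk)
) -> None:
    M, topk = topk_weights.shape
    K = fused_expert_output.shape[-1]
    if inv_perm is not None:
        fused_expert_output = fused_expert_output[inv_perm, ...]
    tmp = fused_expert_output.view(M, topk, K)
    # Weight each expert's contribution, then sum over the topk experts per token.
    output.copy_((tmp * topk_weights.view(M, topk, 1)).sum(dim=1).to(output.dtype))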

vllm/model_executor/layers/fused_moe/modular_kernel.py  (+7 -13)

@@ -9,9 +9,6 @@
 
 class FusedMoEQuantizeDispatchCombine(ABC):
 
-    # def __init__(self):
-    #     pass
-
     @abstractmethod
     def dispatch(
         self,
@@ -21,11 +18,9 @@ def dispatch(
         topk_ids: torch.Tensor,
         num_experts: int,
         expert_map: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
-        # TODO: figure this out
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # returns (quantized+dispatched a,
-        #          quantized+dispatched a1_scales,
-        #          dispatched topk_ids)
+        #          quantized+dispatched a1_scales)
         raise NotImplementedError
 
     @abstractmethod
@@ -34,16 +29,14 @@ def combine(
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,  # not reduced or weighted
         topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
     ) -> None:
         raise NotImplementedError
 
 
 # store weights, etc. here
 class FusedMoEPermuteExpertsUnpermute(ABC):
 
-    # def __init__(self):
-    #     pass
-
     @abstractmethod
     def workspace_shapes(
         self,
@@ -115,6 +108,7 @@ def forward(
         # two, so it's not "correct" to extract N or K from the trailing dimension of
         # w1 or w2. Similarly, some kernels transpose the weights, so this needs to
         # be kept in mind.
+        # TODO: make this a method/utility function, e.g. problem_size(a, w1, w2, topk_ids, ...)
         M, _ = a1.shape
         E, N, _ = w1.shape
         K = w2.shape[1]
@@ -144,7 +138,7 @@ def forward(
                                  device=a1.device,
                                  dtype=workspace_dtype)
 
-        a1q, a1q_scale, dispatched_topk_ids = self.dispatch_combine.dispatch(
+        a1q, a1q_scale = self.dispatch_combine.dispatch(
            a1,
            a1_scale,
            a2_scale,
@@ -157,7 +151,7 @@ def forward(
            a1q,
            w1,
            w2,
-           dispatched_topk_ids,
+           topk_ids,
            activation,
            global_num_experts,
            expert_map,
@@ -171,6 +165,6 @@ def forward(
            workspace2=workspace2,
        )
 
-        self.dispatch_combine.combine(output, fused_out, topk_weights)
+        self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids)
 
         return output
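
The net effect of the interface change is that topk_ids is no longer round-tripped through dispatch: forward() hands the same topk_ids it was given to both the expert kernel and combine. A condensed sketch of that flow against the updated abstract interface; run_experts is a stand-in for the expert kernel call (whose receiver is not shown in the hunks above), not vLLM API:

from typing import Callable, Optional

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk


def moe_forward_flow_sketch(
    dispatch_combine: mk.FusedMoEQuantizeDispatchCombine,
    run_experts: Callable[..., torch.Tensor],   # stand-in for the fused expert kernel
    output: torch.Tensor,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
) -> torch.Tensor:
    # 1. Quantize and (for distributed implementations) dispatch. dispatch()
    #    now returns only the activations and their scales.
    a1q, a1q_scale = dispatch_combine.dispatch(
        a1, a1_scale, a2_scale, topk_ids, num_experts, expert_map)

    # 2. Run the experts with the original topk_ids; there is no longer a
    #    separate "dispatched_topk_ids" value to thread through.
    fused_out = run_experts(a1q, a1q_scale, topk_ids)

    # 3. Un-dispatch, weight and reduce. combine() now also takes topk_ids.
    dispatch_combine.combine(output, fused_out, topk_weights, topk_ids)
    return output
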
pplx dispatch/combine module (file path not shown above)  (+78 -36)

@@ -1,64 +1,106 @@
 import torch
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import pplx_kernels as pplx
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
 
 
+# Note: use layer.get_all_to_all() to get an AllToAll instance.
+# The max_num_tokens, world_size and dp_size must be the same
+# as the ones used to create the AllToAll. Unfortunately, there's
+# no way(?) to extract this info from AllToAll.
 class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine):
-    def __init__(self, a2a: pplx.AllToAll):
+    def __init__(
+        self,
+        a2a: pplx.AllToAll,
+        max_num_tokens: int,
+        world_size: int,
+        dp_size: int,
+        block_shape: Optional[List[int]] = None):
         super().__init__()
         self.a2a = a2a
+        self.block_shape = block_shape
+        self.dp_num_tokens = max_num_tokens * (world_size // dp_size)
 
     def dispatch(
         self,
         a1: torch.Tensor,
         a1_scale: Optional[torch.Tensor],
         a2_scale: Optional[torch.Tensor],
-        topk_ids: torch.Tensor,
+        rank_topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Is this always going to be a1.device?
+        device = a1.device
+
+        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+            a2_scale.numel() != 1 if a2_scale is not None else False)
+
+        a1q, a1q_scale = _fp8_quantize(
+            a1,
+            a1_scale,
+            self.block_shape,
+            per_act_token,
+        )
+
+        expert_num_tokens = torch.empty(
+            num_experts,
+            dtype=torch.int32,
+            device=device,
+        )
+
+        expert_x = torch.empty(
+            (num_experts, self.dp_num_tokens, a1q.shape[-1]),
+            dtype=a1q.dtype,
+            device=device,
+        )
+
+        expert_x_scale: torch.Tensor | None = None
+        if a1q.dtype.itemsize == 1:
+            float32_size = torch.float32.itemsize
+            block_size = (self.block_shape[0] if self.block_shape is not None else 1) * float32_size
+            expert_x_scale = torch.empty(
+                (
+                    num_experts,
+                    expert_x.size(1),
+                    (expert_x.size(2) + block_size - 1) // block_size,
+                ),
+                dtype=torch.float32,
+                device=device,
+            )
+
+        # This argument is optional
+        bound_m = torch.tensor([a1q.shape[0]], dtype=torch.uint32, device=device)
+
         self.a2a.dispatch(
-            out_expert_num_tokens,  # torch.Tensor,
-            out_expert_x,  # torch.Tensor,
-            out_expert_x_scale,  # torch.Tensor | None,
-            dp_x,  # torch.Tensor,
-            dp_x_scale,  # torch.Tensor | None,
-            indices,  # torch.Tensor,
-            bound_m,  # torch.Tensor | None,
-            do_send,  # bool = True,
-            do_recv,  # bool = True,
+            out_expert_num_tokens=expert_num_tokens,
+            out_expert_x=expert_x,
+            out_expert_x_scale=expert_x_scale,
+            dp_x=a1q,
+            dp_x_scale=a1q_scale,
+            indices=rank_topk_ids,
+            bound_m=bound_m,
         )
-        return 1q, a1q_scale, topk_ids
+        return expert_x, expert_x_scale
 
     def combine(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
         topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
     ) -> None:
-        self.a2a.combine(
-            out_tokens,  #: torch.Tensor,
-            indices,  #: torch.Tensor,
-            weights,  #: torch.Tensor,
-            expert_y,  #: torch.Tensor,
-            bound_m,  #: torch.Tensor | None,
-            do_send,  #: bool = True,
-            do_recv,  #: bool = True,
-        )
+        # This argument is optional
+        bound_m = torch.tensor([output.shape[0]], dtype=torch.uint32, device=output.device)
 
+        # TODO: assert output is the proper size
 
-# singleton-ish
-def get_a2a(
-    max_num_tokens: int,
-    num_experts: int,
-    experts_per_token: int,
-    rank: int,
-    world_size: int,
-    dp_size: int,
-    hidden_dim: int,
-    hidden_dim_bytes: int,
-    hidden_dim_scale_bytes: int,
-) -> pplx.AllToAll:
-    pass
+        self.a2a.combine(
+            out_tokens=output,
+            indices=topk_ids,
+            weights=topk_weights,
+            expert_y=fused_expert_output,
+            bound_m=bound_m
+        )
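
For orientation, a hedged usage sketch of the class above. Per the note at the top of the file, the AllToAll handle is expected to come from layer.get_all_to_all(), and the sizes passed to PplxDispatchCombine must match the ones the AllToAll was created with. The construction shown in the comments, the run_experts stand-in, and the tensor shapes are assumptions for illustration, not code from this commit:

from typing import Callable

import torch

# Assumed setup (see the note at the top of the file); PplxDispatchCombine is
# the class defined above, its import path is omitted here:
#
#   a2a = layer.get_all_to_all()
#   dispatch_combine = PplxDispatchCombine(
#       a2a, max_num_tokens=128, world_size=8, dp_size=2, block_shape=None)

def pplx_moe_layer_sketch(
    dispatch_combine,                           # a PplxDispatchCombine instance
    a1: torch.Tensor,                           # (num_tokens, hidden_dim) rank-local activations
    topk_weights: torch.Tensor,                 # (num_tokens, topk)
    topk_ids: torch.Tensor,                     # (num_tokens, topk) int32 expert ids
    num_experts: int,
    run_experts: Callable[..., torch.Tensor],   # stand-in for the grouped expert kernel
) -> torch.Tensor:
    # dispatch(): fp8-quantize a1 and all-to-all it so each rank holds the
    # tokens routed to its experts, laid out as (num_experts, dp_num_tokens, K).
    expert_x, expert_x_scale = dispatch_combine.dispatch(
        a1, None, None, topk_ids, num_experts, expert_map=None)

    # Expert computation runs on the dispatched layout.
    expert_y = run_experts(expert_x, expert_x_scale)

    # combine(): all-to-all the expert outputs back to their source ranks and
    # reduce them with the topk weights into `output`.
    output = torch.empty_like(a1)
    dispatch_combine.combine(output, expert_y, topk_weights, topk_ids)
    return output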
