
Commit 9dc7c6c

[dbrx] refactor dbrx experts to extend FusedMoe class (vllm-project#8518)
1 parent ec4aaad

File tree

1 file changed: +51 −69 lines

vllm/model_executor/models/dbrx.py

Lines changed: 51 additions & 69 deletions
@@ -7,9 +7,8 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
-from vllm.model_executor.layers.fused_moe import fused_moe
+                              get_tensor_model_parallel_world_size)
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -22,7 +21,6 @@
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
@@ -54,63 +52,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return router_logits
 
 
-class DbrxExperts(nn.Module):
-    """A tensor-parallel MoE implementation for DBRX.
-
-    Each expert's weights are sharded across all ranks and a fused MoE
-    kernel is used for the forward pass, and finally we reduce the outputs
-    across ranks.
-    """
+class DbrxExperts(FusedMoE):
 
     def __init__(
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
     ):
-        super().__init__()
+        super().__init__(
+            num_experts=config.ffn_config.moe_num_experts,
+            top_k=config.ffn_config.moe_top_k,
+            hidden_size=config.d_model,
+            intermediate_size=config.ffn_config.ffn_hidden_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=True,
+            quant_config=quant_config,
+            tp_size=get_tensor_model_parallel_world_size(),
+        )
+        self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.ffn_config.moe_num_experts
-        self.top_k = config.ffn_config.moe_top_k
         self.d_model = config.d_model
-        self.intermediate_size = (config.ffn_config.ffn_hidden_size //
+        self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
                                   self.tp_size)
 
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-
-        self.router = DbrxRouter(config, self.params_dtype)
-        self.ws = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.d_model,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-        self.w2s = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.d_model,
-                self.intermediate_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-
-        set_weight_attrs(
-            self.ws,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2s,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-
+    # Define custom weight loader for dbrx model
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str):
         tp_rank = get_tensor_model_parallel_rank()
@@ -140,26 +107,40 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
         ).transpose(1, 2)
         param_data[:] = loaded_weight[:, :, shard]
 
+
+class DbrxMoE(nn.Module):
+    """A tensor-parallel MoE implementation for DBRX.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.router = DbrxRouter(config, self.params_dtype)
+
+        self.experts = DbrxExperts(config=config,
+                                   quant_config=quant_config,
+                                   params_dtype=self.params_dtype)
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.d_model)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.router(hidden_states)
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.ws,
-            self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)
 
 
 class DbrxAttention(nn.Module):
@@ -288,7 +269,7 @@ def __init__(
         super().__init__()
         self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
                                                      quant_config)
-        self.ffn = DbrxExperts(config, quant_config)
+        self.ffn = DbrxMoE(config, quant_config)
 
     def forward(
         self,
@@ -409,9 +390,10 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
         expert_params_mapping = [(
-            "ws" if weight_name in ["w1", "v1"] else "w2s",
-            f"experts.mlp.{weight_name}",
+            "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
+            f"mlp.{weight_name}",
         ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
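For orientation, the refactor splits responsibilities: DbrxExperts now subclasses FusedMoE (constructed with reduce_results=True and renormalize=True, so the fused path handles top-k renormalization and the cross-rank reduction the old code did by hand with tensor_model_parallel_all_reduce), while the new DbrxMoE wrapper owns the router and delegates to self.experts(hidden_states, router_logits). The load_weights change follows the same move: DBRX checkpoint tensors w1/v1 now land in the inherited w13_weight parameter and w2 in w2_weight, with the custom weight_loader slicing out each rank's shard. The sketch below is a minimal, standalone plain-PyTorch approximation of that control flow, with no tensor parallelism, quantization, or fused kernels; ToyExperts, ToyMoE, and every dimension in it are illustrative names and sizes, not vLLM APIs.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyExperts(nn.Module):
    """Simplified stand-in for a FusedMoE-style experts module."""

    def __init__(self, num_experts: int, top_k: int, d_model: int, d_ffn: int):
        super().__init__()
        self.top_k = top_k
        # w13: gate and up projections stacked per expert; w2: down projection.
        self.w13 = nn.Parameter(torch.randn(num_experts, 2 * d_ffn, d_model) * 0.02)
        self.w2 = nn.Parameter(torch.randn(num_experts, d_model, d_ffn) * 0.02)

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        # Top-k routing weights, renormalized after selection
        # (roughly what renormalize=True means for the fused path).
        weights = F.softmax(router_logits, dim=-1)
        topk_w, topk_idx = weights.topk(self.top_k, dim=-1)
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)

        out = torch.zeros_like(hidden_states)
        for k in range(self.top_k):
            idx = topk_idx[:, k]                      # expert id per token
            w13 = self.w13[idx]                       # (tokens, 2*d_ffn, d_model)
            w2 = self.w2[idx]                         # (tokens, d_model, d_ffn)
            x = torch.einsum("td,tfd->tf", hidden_states, w13)
            gate, up = x.chunk(2, dim=-1)
            x = F.silu(gate) * up                     # gated expert MLP
            x = torch.einsum("tf,tdf->td", x, w2)
            out += topk_w[:, k:k + 1] * x             # weighted expert output
        return out


class ToyMoE(nn.Module):
    """Mirrors DbrxMoE: the wrapper owns the router and delegates to experts."""

    def __init__(self, num_experts=4, top_k=2, d_model=32, d_ffn=64):
        super().__init__()
        self.d_model = d_model
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = ToyExperts(num_experts, top_k, d_model, d_ffn)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.d_model)
        router_logits = self.router(hidden_states)
        out = self.experts(hidden_states, router_logits)
        return out.view(orig_shape)


if __name__ == "__main__":
    moe = ToyMoE()
    x = torch.randn(2, 5, 32)    # (batch, seq_len, d_model)
    print(moe(x).shape)          # torch.Size([2, 5, 32])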
