@@ -7,9 +7,8 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
-from vllm.model_executor.layers.fused_moe import fused_moe
+                              get_tensor_model_parallel_world_size)
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -22,7 +21,6 @@
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
@@ -54,63 +52,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return router_logits
 
 
-class DbrxExperts(nn.Module):
-    """A tensor-parallel MoE implementation for DBRX.
-
-    Each expert's weights are sharded across all ranks and a fused MoE
-    kernel is used for the forward pass, and finally we reduce the outputs
-    across ranks.
-    """
+class DbrxExperts(FusedMoE):
 
     def __init__(
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
     ):
-        super().__init__()
+        super().__init__(
+            num_experts=config.ffn_config.moe_num_experts,
+            top_k=config.ffn_config.moe_top_k,
+            hidden_size=config.d_model,
+            intermediate_size=config.ffn_config.ffn_hidden_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=True,
+            quant_config=quant_config,
+            tp_size=get_tensor_model_parallel_world_size(),
+        )
+        self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.ffn_config.moe_num_experts
-        self.top_k = config.ffn_config.moe_top_k
         self.d_model = config.d_model
-        self.intermediate_size = (config.ffn_config.ffn_hidden_size //
+        self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
                                   self.tp_size)
 
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-
-        self.router = DbrxRouter(config, self.params_dtype)
-        self.ws = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.d_model,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-        self.w2s = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.d_model,
-                self.intermediate_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-
-        set_weight_attrs(
-            self.ws,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2s,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-
+    # Define custom weight loader for dbrx model
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str):
         tp_rank = get_tensor_model_parallel_rank()
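Note: a quick sanity sketch of the tensor-parallel sharding that the FusedMoE base class now manages. In the removed code each rank allocated only its `ffn_hidden_size // tp_size` slice of every expert's weights; the toy PyTorch snippet below shows that slicing along the intermediate dimension (illustrative shapes only, not DBRX's actual checkpoint layout or vLLM's loader).

```python
# Toy sketch (not the vLLM implementation): tensor-parallel sharding of
# per-expert weights along the intermediate dimension, as in the removed
# DbrxExperts code where intermediate_size = ffn_hidden_size // tp_size.
import torch

num_experts, d_model, ffn_hidden_size, tp_size = 4, 8, 16, 2
shard_size = ffn_hidden_size // tp_size

# Full up-projection weights for all experts: (num_experts, ffn, d_model).
full_w1 = torch.randn(num_experts, ffn_hidden_size, d_model)

# Each rank keeps a contiguous slice of every expert's rows.
shards = [
    full_w1.narrow(1, tp_rank * shard_size, shard_size)
    for tp_rank in range(tp_size)
]
assert shards[0].shape == (num_experts, shard_size, d_model)
# Concatenating the per-rank shards along the sharded dim recovers the tensor.
assert torch.equal(torch.cat(shards, dim=1), full_w1)
```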
@@ -140,26 +107,40 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
         ).transpose(1, 2)
         param_data[:] = loaded_weight[:, :, shard]
 
+
+class DbrxMoE(nn.Module):
+    """A tensor-parallel MoE implementation for DBRX.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.router = DbrxRouter(config, self.params_dtype)
+
+        self.experts = DbrxExperts(config=config,
+                                   quant_config=quant_config,
+                                   params_dtype=self.params_dtype)
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.d_model)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.router(hidden_states)
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.ws,
-            self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)
 
 
 class DbrxAttention(nn.Module):
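Note: a self-contained reference for the data flow in `DbrxMoE.forward` above: flatten tokens, compute router logits, let the experts module handle top-k routing plus the expert MLPs, then restore the original shape. The loop below is a dense, single-process stand-in for vLLM's fused kernel and tensor-parallel all-reduce; all names and sizes are illustrative.

```python
# Toy sketch of the DbrxMoE.forward data flow (plain PyTorch, not the fused
# vLLM kernel). Router picks top_k experts per token, weights are
# renormalized (renormalize=True), and expert outputs are summed.
import torch

d_model, n_experts, top_k = 8, 4, 2
router = torch.nn.Linear(d_model, n_experts, bias=False)
experts = torch.nn.ModuleList(
    [torch.nn.Linear(d_model, d_model, bias=False) for _ in range(n_experts)])

def moe_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    orig_shape = hidden_states.shape                # e.g. (batch, seq, d_model)
    hidden_states = hidden_states.view(-1, d_model)
    router_logits = router(hidden_states)           # (num_tokens, n_experts)
    probs = torch.softmax(router_logits, dim=-1)
    topk_w, topk_idx = probs.topk(top_k, dim=-1)
    topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)  # renormalize=True
    out = torch.zeros_like(hidden_states)
    for slot in range(top_k):
        for e, expert in enumerate(experts):
            mask = topk_idx[:, slot] == e
            if mask.any():
                out[mask] += topk_w[mask, slot].unsqueeze(-1) * expert(
                    hidden_states[mask])
    return out.view(orig_shape)

print(moe_forward(torch.randn(2, 3, d_model)).shape)  # torch.Size([2, 3, 8])
```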
@@ -288,7 +269,7 @@ def __init__(
         super().__init__()
         self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
                                                      quant_config)
-        self.ffn = DbrxExperts(config, quant_config)
+        self.ffn = DbrxMoE(config, quant_config)
 
     def forward(
         self,
@@ -409,9 +390,10 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
         expert_params_mapping = [(
-            "ws" if weight_name in ["w1", "v1"] else "w2s",
-            f"experts.mlp.{weight_name}",
+            "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
+            f"mlp.{weight_name}",
         ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
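Note: the renamed mapping targets FusedMoE's stacked parameters: `w1` (gate) and `v1` (up) both load into `w13_weight`, while `w2` loads into `w2_weight`, and the matched substring no longer includes the `experts.` prefix. A standalone illustration of the mapping (the example checkpoint key is hypothetical):

```python
# Standalone illustration of the expert_params_mapping above. "w1" and "v1"
# both target FusedMoE's stacked "w13_weight"; "w2" targets "w2_weight".
expert_params_mapping = [(
    "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
    f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]

# Hypothetical checkpoint key, for illustration only.
example_key = "transformer.blocks.0.ffn.experts.mlp.w1"
for param_name, weight_name in expert_params_mapping:
    if weight_name in example_key:
        print(f"{example_key} -> loads into {param_name}")
        # transformer.blocks.0.ffn.experts.mlp.w1 -> loads into w13_weight
```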