Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,14 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
elif shard_id == "w2":
param_data[expert_id] = loaded_weight

def _load_w13_weight_scale(self, shard_dim: int,
loaded_weight: torch.Tensor,
param: torch.Tensor, tp_rank: int):
shard_size = param.shape[shard_dim]
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
param.copy_(loaded_weight)

def _load_model_weight_or_group_weight_scale(self,
shard_dim: int,
expert_data: torch.Tensor,
Expand Down Expand Up @@ -1123,7 +1131,12 @@ def weight_loader(self,
"weight_scale_2" in weight_name if uses_weight_scale_2 else
"weight_scale" in weight_name) or "input_scale" in weight_name

if per_tensor_conditions:
if "w13_weight_scale" in weight_name:
self._load_w13_weight_scale(shard_dim=shard_dim,
loaded_weight=loaded_weight,
param=param,
tp_rank=self.tp_rank)
elif per_tensor_conditions:
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
Expand Down
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,8 +778,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
# Swizzle the weight blockscale.
# contracting dimension is input dimension
# block_size = 16;
assert (layer.weight_scale.shape[1] % 16 == 0), (
"Expected weight_scale.dim(1) to be divisible by 16")
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Block scale must be represented as FP8-E4M3")
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
Expand Down
Loading