
Conversation


@vekoada commented Aug 20, 2025

Summary

This PR improves the MXFP4 dequantization routine by replacing the naive nibble-splitting implementation with a LUT-based vectorized approach. This is a drop-in change that preserves identical numerical outputs but significantly improves runtime efficiency and reduces memory pressure.

Motivation

The current dequantization path:

  • Performs two int64 index computations per chunk (idx_lo, idx_hi)
  • Allocates larger-than-necessary intermediate temporaries

This leads to unnecessary overhead and higher peak memory use, especially for large checkpoint tensors.

Changes

  • Replaced nibble-splitting logic with a precomputed [256, 2] LUT for vectorized lookup
  • Reduced int64 temporaries from two per chunk to one
  • Preserved output shape and dtype exactly (bitwise identical results)
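
For intuition, the pair-LUT idea described above can be sketched in a few lines of plain Python (an illustration only, not the PR's actual torch code — the actual implementation indexes a `[256, 2]` tensor in one vectorized gather). `FP4_VALUES` here is the standard OCP FP4 (E2M1) decode table:

```python
# Standard OCP FP4 (E2M1) decode table, indexed by the 4-bit code:
# sign bit in the MSB, so codes 8-15 mirror codes 0-7 with negative sign.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

# [256, 2] pair LUT: byte -> (low-nibble value, high-nibble value),
# so one lookup per byte replaces two nibble extractions.
PAIR_LUT = [(FP4_VALUES[b & 0x0F], FP4_VALUES[b >> 4]) for b in range(256)]

def dequant_bytes(packed: bytes) -> list:
    """Decode packed FP4 pairs with a single table lookup per byte."""
    out = []
    for b in packed:
        lo, hi = PAIR_LUT[b]
        out.extend((lo, hi))
    return out
```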

Performance Results

Benchmarked on T4 with representative checkpoint tensors:

| Implementation | Time | Peak Memory |
|---|---|---|
| Default | 0.677 s | 11.40 GB |
| Optimized | 0.514 s | 11.02 GB |

~24% faster runtime
~3% lower peak memory
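
The headline percentages follow directly from the table; as a quick arithmetic check:

```python
# Recompute the headline numbers from the benchmark table above.
default_s, optimized_s = 0.677, 0.514
default_gb, optimized_gb = 11.40, 11.02

speedup = (default_s - optimized_s) / default_s        # fraction of runtime saved
mem_saved = (default_gb - optimized_gb) / default_gb   # fraction of peak memory saved
print(f"{speedup:.1%} faster, {mem_saved:.1%} lower peak memory")
```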

This change does not introduce new features. It corrects inefficiencies in the current dequantization path. The outputs remain identical and the only difference is improved speed and memory efficiency.

@vekoada (Author) commented Aug 20, 2025

Extrapolating the ~160 ms saved on this tensor across all expert weights suggests a ~7 s reduction in total model load time, and the ~380 MB of peak VRAM saved gives extra headroom for more stable loads on consumer GPUs.

@vekoada changed the title from "fix(weights): optimize MXFP4 dequantization for speed and memory" to "perf(weights): ~24% speedup and ~3% peak memory reduction for MXFP4 dequantization" on Aug 20, 2025
@Sahil3378 left a comment

Duplicate function definitions — `get`, `_get_tensor`, `_get_mxfp4_tensor` are pasted twice.

Duplicate argument — you have two `rows_per_chunk` definitions:

```python
torch.ldexp(sub, exp, out=sub)
del idx_lo, idx_hi, blk, exp
```


```python
import math

import torch
from safetensors import safe_open  # assuming safetensors is used

FP4_VALUES = [...]  # lookup table values must be defined
# PARAM_NAME_MAP is assumed to be defined elsewhere: it maps parameter names
# to (blocks_name, scales_name) tuples for MXFP4-packed weights.

class TensorLoader:
    def __init__(self, tensor_name_to_file, device="cpu"):
        self.tensor_name_to_file = tensor_name_to_file
        self.device_str = device

    def get(self, name: str) -> torch.Tensor:
        match PARAM_NAME_MAP.get(name, name):
            case (blocks_name, scales_name):
                # MoE weights: stored in block-based MXFP4 format
                return self._get_mxfp4_tensor(blocks_name, scales_name, dtype=torch.bfloat16)
            case tensor_name:
                # MoE biases and other weights
                return self._get_tensor(tensor_name)

    def _get_tensor(self, name: str) -> torch.Tensor:
        assert name in self.tensor_name_to_file, f"Tensor {name} not found in checkpoint."
        with safe_open(self.tensor_name_to_file[name], framework="pt", device=self.device_str) as f:
            return f.get_tensor(name)

    def _get_mxfp4_tensor(
        self,
        blocks_name: str,
        scales_name: str,
        *,
        dtype: torch.dtype = torch.bfloat16,
        rows_per_chunk: int = 12288 * 512,  # keep only one definition
    ) -> torch.Tensor:
        # Load blocks + scales; scales are stored as biased uint8 (bias 127)
        blocks = self._get_tensor(blocks_name)
        scales = self._get_tensor(scales_name).to(torch.int32) - 127

        assert blocks.shape[:-1] == scales.shape, (
            f"{blocks.shape=} does not match {scales.shape=}"
        )

        # Base LUT for FP4 decoding
        lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)

        *prefix_shape, G, B = blocks.shape
        rows_total = math.prod(prefix_shape) * G

        blocks = blocks.reshape(rows_total, B)
        scales = scales.reshape(rows_total, 1)

        # Output tensor: each packed byte expands to two FP4 values
        out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)

        for r0 in range(0, rows_total, rows_per_chunk):
            r1 = min(r0 + rows_per_chunk, rows_total)

            blk = blocks[r0:r1]
            exp = scales[r0:r1]

            # split each 8-bit byte into 2x 4-bit codes (nibbles)
            idx_lo = (blk & 0x0F).to(torch.long)
            idx_hi = (blk >> 4).to(torch.long)

            sub = out[r0:r1]
            sub[:, 0::2] = lut[idx_lo]
            sub[:, 1::2] = lut[idx_hi]

            # Apply the shared power-of-two block exponent
            torch.ldexp(sub, exp, out=sub)

            # free chunk temporaries before the next iteration
            del idx_lo, idx_hi, blk, exp

        return out
```
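
The `torch.ldexp` call above applies the shared block scale: the scale is a power-of-two exponent stored as a biased uint8 (bias 127, subtracted when the scales are loaded). The same arithmetic in plain Python is just `math.ldexp` — a small sketch for intuition, with a hypothetical helper name:

```python
import math

def apply_block_scale(values, stored_scale):
    """Multiply decoded FP4 values by the block's shared power-of-two scale.

    stored_scale is the raw uint8 from the checkpoint; subtracting the
    bias of 127 recovers the (possibly negative) exponent.
    """
    exp = stored_scale - 127
    return [math.ldexp(v, exp) for v in values]
```

For example, a stored scale of 128 means an exponent of +1, so every value in the block is doubled; a stored scale of 126 halves them.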
