Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 22 additions & 27 deletions gpt_oss/torch/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,49 +71,44 @@ def _get_mxfp4_tensor(
scales_name: str,
*,
dtype: torch.dtype = torch.bfloat16,
rows_per_chunk: int = 16384 * 512,
rows_per_chunk: int = 12288 * 512,
) -> torch.Tensor:
assert blocks_name in self.tensor_name_to_file, (
f"Blocks tensor {blocks_name} not found in checkpoint."
)
assert scales_name in self.tensor_name_to_file, (
f"Scales tensor {scales_name} not found in checkpoint."
)

blocks = self._get_tensor(blocks_name)
scales = self._get_tensor(scales_name).to(torch.int32) - 127

assert blocks.shape[:-1] == scales.shape, (
f"{blocks.shape=} does not match {scales.shape=}"
)

lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)


*prefix_shape, G, B = blocks.shape
rows_total = math.prod(prefix_shape) * G

blocks = blocks.reshape(rows_total, B)
scales = scales.reshape(rows_total, 1)

rows_total = math.prod(prefix_shape) * G

# Build a 256x2 LUT once (device+dtype of blocks)
base_lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
byte_vals = torch.arange(256, device=blocks.device)
lut_full = torch.stack(
(base_lut[byte_vals & 0x0F], base_lut[byte_vals >> 4]),
dim=1, # [256, 2] -> [low, high]
)

blocks_flat = blocks.reshape(rows_total, B) # reshape (not view)
scales_flat = scales.reshape(rows_total, 1)

out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)

for r0 in range(0, rows_total, rows_per_chunk):
r1 = min(r0 + rows_per_chunk, rows_total)

blk = blocks[r0:r1]
exp = scales[r0:r1]

# nibble indices -> int64
idx_lo = (blk & 0x0F).to(torch.long)
idx_hi = (blk >> 4).to(torch.long)

sub = out[r0:r1]
sub[:, 0::2] = lut[idx_lo]
sub[:, 1::2] = lut[idx_hi]

torch.ldexp(sub, exp, out=sub)
del idx_lo, idx_hi, blk, exp

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import torch, math
from safetensors import safe_open # assuming safetensors is used

FP4_VALUES = [...] # lookup table values must be defined

class TensorLoader:
def init(self, tensor_name_to_file, device="cpu"):
self.tensor_name_to_file = tensor_name_to_file
self.device_str = device

def get(self, name: str) -> torch.Tensor:
    match PARAM_NAME_MAP.get(name, name):
        case (blocks_name, scales_name):
            # MoE weights: in block-based MXFP4 format
            return self._get_mxfp4_tensor(blocks_name, scales_name, dtype=torch.bfloat16)
        case tensor_name:
            # MoE biases and other weights
            return self._get_tensor(tensor_name)

def _get_tensor(self, name: str) -> torch.Tensor:
    assert name in self.tensor_name_to_file, f"Tensor {name} not found in checkpoint."
    with safe_open(self.tensor_name_to_file[name], framework="pt", device=self.device_str) as f:
        return f.get_tensor(name)

def _get_mxfp4_tensor(
    self,
    blocks_name: str,
    scales_name: str,
    *,
    dtype: torch.dtype = torch.bfloat16,
    rows_per_chunk: int = 12288 * 512,  # keep only one definition
) -> torch.Tensor:

    # Load blocks + scales
    blocks = self._get_tensor(blocks_name)
    scales = self._get_tensor(scales_name).to(torch.int32) - 127

    assert blocks.shape[:-1] == scales.shape, (
        f"{blocks.shape=} does not match {scales.shape=}"
    )

    # Base LUT for FP4 decoding
    lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)

    *prefix_shape, G, B = blocks.shape
    rows_total = math.prod(prefix_shape) * G

    blocks = blocks.reshape(rows_total, B)
    scales = scales.reshape(rows_total, 1)

    # Output tensor
    out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)

    for r0 in range(0, rows_total, rows_per_chunk):
        r1 = min(r0 + rows_per_chunk, rows_total)

        blk = blocks[r0:r1]
        exp = scales[r0:r1]

        # split 8-bit into 2x 4-bit (nibbles)
        idx_lo = (blk & 0x0F).to(torch.long)
        idx_hi = (blk >> 4).to(torch.long)

        sub = out[r0:r1]
        sub[:, 0::2] = lut[idx_lo]
        sub[:, 1::2] = lut[idx_hi]

        # Apply exponent scaling
        torch.ldexp(sub, exp, out=sub)

        # cleanup
        del idx_lo, idx_hi, blk, exp

    return out

blk_chunk = blocks_flat[r0:r1].to(torch.long) # one int64 copy
mant = lut_full[blk_chunk].reshape(r1 - r0, B * 2)
torch.ldexp(mant, scales_flat[r0:r1], out=out[r0:r1])

return out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)

def _get_mxfp4_tensor_copy(self, blocks_name: str, scales_name: str, dtype: torch.dtype = torch.bfloat16):
Expand Down