@faruknane faruknane commented Sep 21, 2025

Hi,

I have worked on improving the replace function and added layer analytics to FP-Quant. I also encountered and fixed a bug, and demonstrated how to quickly quantize the Flux Kontext model on the fly. Please take a look at the code below; I'm willing to integrate the parts you confirm.

Bug Fix:

  • linear_fns.py: I changed alpha from a float to a Tensor, since Qutlass expects a tensor as the alpha parameter.
@torch.library.custom_op("fp_quant::matmul_mxf4_bf16_tn_op", mutates_args=())
def matmul_mxf4_bf16_tn_op(
    x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: torch.Tensor
) -> torch.Tensor:
    return matmul_mxf4_bf16_tn(
        x, w, to_blocked(xs), to_blocked(ws).view(torch.float8_e8m0fnu), alpha
    )

.
.
.
        # Quantize input
        x_flat_q, x_flat_scales, x_flat_mask = forward_quantize(
            x_flat, forward_hadamard_matrix, dtype, forward_method
        )

        # alpha is now passed as a tensor, as expected by Qutlass
        y = forward_gemm(
            x_flat_q, weight_q, x_flat_scales, weight_scales,
            torch.tensor([1.0 / 9.0], device=x.device),
        )

        y = y.unflatten(dim=0, sizes=x.shape[:-1])
        if bias is not None:
            y += bias
.
.
.

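One related note: the Flux Kontext example below runs the transformer through torch.compile, which needs a fake (meta) implementation of this custom op to trace it without executing the kernel. The PR may already register one; if not, a sketch could look like the following, where the (M, N) bf16 output shape is an assumption and should be adjusted to match what matmul_mxf4_bf16_tn actually returns.

# Hypothetical fake/meta registration so torch.compile can infer output
# shapes for the custom op without running the Qutlass kernel.
@matmul_mxf4_bf16_tn_op.register_fake
def _(
    x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: torch.Tensor
) -> torch.Tensor:
    # Assumption: output is (M, N) bf16, where M = x.shape[0], N = w.shape[0].
    return x.new_empty((x.shape[0], w.shape[0]), dtype=torch.bfloat16)
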
Replace

  • replace.py: It is now a dynamic function that can take a model on CUDA and quantize it to FP4 on the fly. It handles copying, quantizing, and deleting the weights, and it can optionally apply the pre_forward method that is required for quantization to take place.
                new = FPQuantLinear(
                    in_features,
                    out_features,
                    config=fp_quant_linear_config,
                    bias=module.bias is not None,
                    device=module.weight.device,
                    dtype=module.weight.dtype,
                    name=current_key_name_str,
                    enable_analytics=enable_analytics,
                )
                with torch.no_grad():
                    if hasattr(new, "load_from_linear"):
                        new.load_from_linear(module)  # hypothetical helper
                    else:
                        # fallback if FPQuantLinear stores real-valued weights
                        if hasattr(new, "weight") and hasattr(module, "weight"):
                            new.weight.copy_(module.weight)
                        if hasattr(new, "bias") and hasattr(module, "bias") and module.bias is not None:
                            new.bias.copy_(module.bias)
                
                    model._modules[name] = new
                    
                    # .to() is not in-place; reassign so the original parameters
                    # actually move to CPU and the GPU copies can be freed.
                    module.weight.data = module.weight.data.to("cpu")
                    if module.bias is not None:
                        module.bias.data = module.bias.data.to("cpu")

                    has_been_replaced = True

                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)

                    # Force delete the tensors here
                    if hasattr(module, "weight") and module.weight is not None:
                        num_of_params += module.weight.numel()
                        del module.weight
                        module._parameters.pop("weight", None)
                        module.register_parameter("weight", None)

                    if hasattr(module, "bias") and module.bias is not None:
                        del module.bias
                        module._parameters.pop("bias", None)
                        module.register_parameter("bias", None)

                    del module  

                    if apply_pre_forward:
                        model._modules[name].pre_forward()

                torch.cuda.empty_cache()
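
For quick reference, here is a minimal, self-contained usage sketch of the updated replace function on a toy module. The call signature and return values mirror the Flux Kontext example further below; the toy model and the exact config values are placeholders.

import torch
import torch.nn as nn

from fp_quant.inference_lib.src.fp_quant import (
    FPQuantConfig,
    FPQuantDtype,
    replace_with_fp_quant_linear,
)

# Toy model with a couple of Linear layers, already on CUDA in bf16.
model = nn.Sequential(
    nn.Linear(4096, 4096),
    nn.GELU(),
    nn.Linear(4096, 1024),
).to("cuda", dtype=torch.bfloat16)

config = FPQuantConfig(
    forward_dtype=FPQuantDtype.MXFP4,
    forward_method="abs_max",
    backward_dtype=FPQuantDtype.BF16,
    hadamard_group_size=32,
    modules_to_not_convert=[],
)

# Replace the nn.Linear modules in place; apply_pre_forward=True quantizes the
# copied weights immediately so the original bf16 parameters can be freed.
_, replaced, num_of_params = replace_with_fp_quant_linear(
    model, config, apply_pre_forward=True, enable_analytics=False
)
print("Replaced:", replaced, "Quantized params:", num_of_params)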

Layer Analytics

  • layer_analytics.py: This file performs runtime analysis of the quantized layers. Each quantized layer automatically calls the analyze method and registers itself with the layer analytics. After all quantized layers are registered, we measure how long the quantized version takes to run versus the normal version and save these results to a .json file. When the program is restarted, we use these stored timings to decide which layers to quantize and which to leave as-is.
    # Now for each unique layer
    for key in unique_layers:
        input_shape, in_features, out_features, device, dtype, bias = key
        config = unique_layers[key]["config"]

        device = torch.device(device)
        dtype = getattr(torch, dtype.split('.')[-1])
        nn_layer = torch.nn.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
        from .linear import FPQuantLinear
        quantized_layer = FPQuantLinear(in_features, out_features, bias=bias, config=config, device=device, dtype=dtype)
        quantized_layer.pre_forward()

        with torch.no_grad():
            sample_input = torch.randn(*input_shape, device=device, dtype=dtype)

            def b1():
                return nn_layer(sample_input)
            
            def b2():
                return quantized_layer(sample_input)

            quantized_layer_time, _ = bench_ms(b2, warmup=100, iters=200)
            nn_layer_time, _ = bench_ms(b1, warmup=100, iters=200)

            print(f"Layer: {key}, nn_layer_time: {nn_layer_time:.3f} ms, quantized_layer_time: {quantized_layer_time:.3f} ms, ratio: {quantized_layer_time / nn_layer_time:.3f}")

            unique_layers[key]["quantized_layer_time"] = quantized_layer_time
            unique_layers[key]["nn_layer_time"] = nn_layer_time
        
        del quantized_layer
        del nn_layer
        torch.cuda.empty_cache()
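
bench_ms above is a timing helper that is not shown in this snippet. As a point of reference, a minimal CUDA-event-based helper with the same (fn, warmup=..., iters=...) interface and a (mean_time_ms, samples) return value could look like the following; this is a hypothetical stand-in, not necessarily the implementation used in the PR.

import torch

def bench_ms(fn, warmup=100, iters=200):
    """Hypothetical timing helper: returns (mean_ms, per_iter_ms) for fn()."""
    # Warmup runs to trigger lazy initialization, autotuning, and caching.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    per_iter_ms = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        end.synchronize()
        per_iter_ms.append(start.elapsed_time(end))  # milliseconds

    return sum(per_iter_ms) / len(per_iter_ms), per_iter_ms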

Quantizing Flux Kontext

  • Readme.md:

  • Here are time measurements for running the quantized model on my RTX 5090, and a quick example of using FP-Quant to quantize a model on the fly.

Runtime per step (bf16): ~790 ms
Runtime per step (partially quantized; layers with quantized_layer_time / nn_layer_time > 0.95 are left in bf16): ~410 ms (measured ~353 ms without CUDA sync and CUDA events)
Runtime per step (fully quantized): ~403 ms

For comparison, here is NVIDIA's TensorRT FP4 result:
[image]

This is not a fair comparison without knowing the details of the machine, GPU clocks, etc. Still, I think it is an acceptable result, especially considering that TensorRT is a highly optimized engine.

import json

import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# pipeline_quant_config is defined earlier (omitted here)
pipe = FluxKontextPipeline.from_pretrained("/home/cropy/flux_kontext",
                                        local_files_only=True,
                                        quantization_config=pipeline_quant_config,
                                        torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Apply Qutlass quantization to the transformer
# Read the layer analytics (if present) to compare each layer’s quantized runtime with the normal runtime.

try:
    with open("layer_analytics.json", "r") as f:
        layer_analytics_list = json.load(f)
        # Leave unquantized any layer whose quantized runtime is not clearly
        # faster than the bf16 runtime.
        layer_analytics_list = [
            key
            for key in layer_analytics_list
            if layer_analytics_list[key]["quantized_layer_time"] / layer_analytics_list[key]["nn_layer_time"] > 0.95
        ]
    enable_analytics = False
except (OSError, json.JSONDecodeError, KeyError):
    layer_analytics_list = []
    print("No layer_analytics.json found, or error in reading it.")
    enable_analytics = True

from fp_quant.inference_lib.src.fp_quant import FPQuantLinear, FPQuantConfig, FPQuantDtype, replace_with_fp_quant_linear

fp_quant_config = FPQuantConfig(
    forward_dtype=FPQuantDtype.MXFP4,
    forward_method="abs_max",
    backward_dtype=FPQuantDtype.BF16,
    hadamard_group_size=32,
    modules_to_not_convert=[
        "x_embedder",  # we should not quantize x_embedder, otherwise the resulting image looks like noise
        *layer_analytics_list,
    ],
)

_, result, num_of_params = replace_with_fp_quant_linear(
    pipe.transformer, fp_quant_config, apply_pre_forward=True, enable_analytics=enable_analytics
)
print("Transformer Replaced:", result, "Num of params:", num_of_params)

pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune", fullgraph=False
)
pipe.vae.decode = torch.compile(
    pipe.vae.decode, mode="max-autotune", fullgraph=True
)

file_name = "images/1.png"

input_image = load_image(file_name)

height = 1024
width = 1024
# PIL's resize is not in-place; keep the returned image.
input_image = input_image.resize((width, height))
num_images_per_prompt = 1


for _ in range(5):
    images = pipe(
        image=input_image,
        prompt="Your prompt!",
        height=height,
        width=width,
        max_area=height*width,
        num_inference_steps=25,
        generator=torch.manual_seed(441),
        num_images_per_prompt=num_images_per_prompt,
    ).images


    for i in range(num_images_per_prompt):
        file_name_i = file_name.replace(".jpg", f"_{i}.jpg").replace(".png", f"_{i}.png")
        images[i].save(file_name_i)
