2 changes: 2 additions & 0 deletions README.md
@@ -1,3 +1,5 @@
See: [https://github.com/IST-DASLab/FP-Quant/pull/6](https://github.com/IST-DASLab/FP-Quant/pull/6)

# FP Format Quantization Harness

This is a harness for efficient and accurate weight-and-activation quantization for low-bit FP/INT formats, with and without microscaling, including FP4, NVFP4, and MXFP. These formats are compatible with the NVIDIA Blackwell GPU architecture.
89 changes: 75 additions & 14 deletions inference_lib/README.md
@@ -1,21 +1,82 @@
# fp_quant
# Quantizing Flux Kontext

A library that wraps [`qutlass`](https://github.com/IST-DASLab/qutlass) kernels with linear layer wrappers for integrations into training and inference engines.
- Here is a quick example of how to use FP-Quant to quantize a model on the fly.

## Installation
~~~python

```bash
pip install .
```
import json

import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# `pipeline_quant_config` is assumed to be defined beforehand (e.g. a
# diffusers pipeline-level quantization config); it is independent of the
# FP-Quant conversion applied below.
pipe = FluxKontextPipeline.from_pretrained("/home/cropy/flux_kontext",
                                           local_files_only=True,
                                           quantization_config=pipeline_quant_config,
                                           torch_dtype=torch.bfloat16)

## Usage
pipe.to("cuda")

```python
from fp_quant import replace_with_fp_quant_linear, FPQuantConfig
# Apply Qutlass quantization to the transformer
try:

# Replace nn.Linear layers with fp_quant.FPQuantLinear
replace_with_fp_quant_linear(
model,
fp_quant_linear_config=FPQuantConfig(),
    # Read layer_analytics.json produced by a previous analytics run.
    with open("layer_analytics.json", "r") as f:
        layer_analytics_list = json.load(f)
    # Skip quantization for layers whose MXFP4 kernel is not at least ~5%
    # faster than the BF16 nn.Linear baseline (time ratio > 0.95).
    layer_analytics_list = [
        key
        for key in layer_analytics_list
        if layer_analytics_list[key]["quantized_layer_time"]
        / layer_analytics_list[key]["nn_layer_time"] > 0.95
    ]
    enable_analytics = False
except (FileNotFoundError, json.JSONDecodeError, KeyError):
    layer_analytics_list = []
    print("No layer_analytics.json found, or error in reading it.")
    enable_analytics = True

from fp_quant.inference_lib.src.fp_quant import FPQuantLinear, FPQuantConfig, FPQuantDtype, replace_with_fp_quant_linear
fp_quant_config = FPQuantConfig(
    forward_dtype=FPQuantDtype.MXFP4,
    forward_method="abs_max",
    backward_dtype=FPQuantDtype.BF16,
    hadamard_group_size=32,
    modules_to_not_convert=[
        # "pos_embed",
        # "text_time_guidance_cls",
        # "time_text_embed",
        # "transformer_blocks",
        # "single_transformer_blocks",
        # "proj_out",
        # "norm_out",
        "x_embedder",
        # "context_embedder",
        *layer_analytics_list,
    ],
)

_, result, num_of_params = replace_with_fp_quant_linear(
    pipe.transformer, fp_quant_config, apply_pre_forward=True, enable_analytics=enable_analytics
)
print("Transformer Replaced:", result, "Num of params:", num_of_params)

pipe.transformer = torch.compile(
pipe.transformer, mode="max-autotune", fullgraph=False
)
```
pipe.vae.decode = torch.compile(
pipe.vae.decode, mode="max-autotune", fullgraph=True
)

file_name = "images/1.png"

input_image = load_image(file_name)

height = 1024
width = 1024
input_image = input_image.resize((width, height))
num_images_per_prompt = 1


for _ in range(5):
    images = pipe(
        image=input_image,
        prompt="Your prompt!",
        height=height,
        width=width,
        max_area=height * width,
        num_inference_steps=25,
        generator=torch.manual_seed(441),
        num_images_per_prompt=num_images_per_prompt,
    ).images


for i in range(num_images_per_prompt):
    file_name_i = file_name.replace(".jpg", f"_{i}.jpg").replace(".png", f"_{i}.png")
    images[i].save(file_name_i)


~~~
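
On the first run there is no `layer_analytics.json`, so `enable_analytics` is `True`: every replaced layer reports its input shape on its first forward call, and once all layers have run, the benchmark results are written to `layer_analytics.json`. On subsequent runs the file is read and any layer whose quantized kernel is not at least ~5% faster than the BF16 `nn.Linear` baseline (time ratio > 0.95) is added to `modules_to_not_convert`, so only layers that actually benefit are quantized.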
133 changes: 133 additions & 0 deletions inference_lib/src/fp_quant/module/layer_analytics.py
@@ -0,0 +1,133 @@
import torch
import json

from ..utils.config import FPQuantConfig, FPQuantDtype


all_added_layer_names = []

layer_list = {}

def bench_ms(fn, warmup=10, iters=100):
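    """Return (mean_ms, std_ms) for `fn`, measured with CUDA events.

    Runs `warmup` un-timed iterations first, then times `iters` iterations,
    synchronizing after each timed call.
    """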
start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)
for _ in range(warmup):
_ = fn()
torch.cuda.synchronize()
times = []
for _ in range(iters):
start_evt.record()
_ = fn()
end_evt.record()
torch.cuda.synchronize()
times.append(start_evt.elapsed_time(end_evt)) # ms
t = torch.tensor(times, device="cpu")
return float(t.mean().item()), float(t.std(unbiased=False).item())

def add_layer(layer: torch.nn.Module, layer_name: str, input_shape: torch.Size, in_features: int, out_features: int, device, dtype):
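    """Record metadata for a quantized layer the first time it executes.

    Once every layer registered during replacement has reported in
    (len(all_added_layer_names) == len(layer_list)), run analyze_layers().
    """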

global layer_list

if layer_name in layer_list:
return

layer_info = {
"config": layer.config,
"bias": layer.bias is not None,
"layer_name": layer_name,
"input_shape": list(input_shape),
"in_features": in_features,
"out_features": out_features,
"device": str(device),
"dtype": str(dtype),
}

layer_list[layer_name] = layer_info

print(f"Layer name: {layer_name}, input shape: {input_shape}, in_features: {in_features}, out_features: {out_features}")

if len(all_added_layer_names) == len(layer_list):
analyze_layers()
print("All layers have been analyzed.")

def analyze_layers():
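    """Benchmark each unique (input_shape, in/out_features, device, dtype, bias) combination.

    Times a plain nn.Linear against an FPQuantLinear built with the same config,
    copies the timings onto every layer sharing that combination, and writes
    the results to layer_analytics.json.
    """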

# First find all unique pairs of (input_shape, in_features, out_features, device, dtype)

global layer_list

unique_layers = {}

for name in layer_list:
bias = layer_list[name]["bias"]
input_shape = tuple(layer_list[name]["input_shape"])
in_features = layer_list[name]["in_features"]
out_features = layer_list[name]["out_features"]
device = layer_list[name]["device"]
dtype = layer_list[name]["dtype"]
config = layer_list[name]["config"]

key = (input_shape, in_features, out_features, device, dtype, bias)

# Only keep the first layer encountered for each unique key
if key not in unique_layers:
unique_layers[key] = {
"config": config
}

# Now for each unique layer
for key in unique_layers:
input_shape, in_features, out_features, device, dtype, bias = key
config = unique_layers[key]["config"]

device = torch.device(device)
dtype = torch.__dict__[dtype.split('.')[-1]]
nn_layer = torch.nn.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
from .linear import FPQuantLinear
quantized_layer = FPQuantLinear(in_features, out_features, bias=bias, config=config, device=device, dtype=dtype)
quantized_layer.pre_forward()

with torch.no_grad():
sample_input = torch.randn(*input_shape, device=device, dtype=dtype)

def b1():
return nn_layer(sample_input)

def b2():
return quantized_layer(sample_input)

quantized_layer_time, _ = bench_ms(b2, warmup=100, iters=200)
nn_layer_time, _ = bench_ms(b1, warmup=100, iters=200)

print(f"Layer: {key}, nn_layer_time: {nn_layer_time:.3f} ms, quantized_layer_time: {quantized_layer_time:.3f} ms, ratio: {quantized_layer_time / nn_layer_time:.3f}")

unique_layers[key]["quantized_layer_time"] = quantized_layer_time
unique_layers[key]["nn_layer_time"] = nn_layer_time

del quantized_layer
del nn_layer
torch.cuda.empty_cache()

    # Copy the benchmark results back onto every layer and strip the
    # non-serializable config before dumping to JSON.
    for name in layer_list:
        bias = layer_list[name]["bias"]
        input_shape = tuple(layer_list[name]["input_shape"])
        in_features = layer_list[name]["in_features"]
        out_features = layer_list[name]["out_features"]
        device = layer_list[name]["device"]
        dtype = layer_list[name]["dtype"]
        key = (input_shape, in_features, out_features, device, dtype, bias)
        layer_list[name]["quantized_layer_time"] = unique_layers[key]["quantized_layer_time"]
        layer_list[name]["nn_layer_time"] = unique_layers[key]["nn_layer_time"]
        del layer_list[name]["config"]

# Save to file
with open("layer_analytics.json", "w") as f:
json.dump(layer_list, f, indent=4)

53 changes: 46 additions & 7 deletions inference_lib/src/fp_quant/module/linear.py
@@ -5,6 +5,9 @@
from scipy.linalg import hadamard

from ..utils import FPQuantConfig, FPQuantDtype

from . import layer_analytics

from .linear_fns import (
HAS_QUTLASS,
FPQuant4x16MasterFn,
@@ -23,6 +26,17 @@ def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.devic
hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device
)

def free_param(param: torch.nn.Parameter):
if param is not None:
# Move off GPU to free CUDA memory
if param.device.type == "cuda":
param.data = param.data.cpu()

# Break the computational graph and drop storage reference
param.detach_()

# Remove the parameter reference
del param

class FPQuantLinear(nn.Module):
def __init__(
@@ -33,6 +47,8 @@ def __init__(
bias: bool = True,
device: torch.device = None,
dtype: torch.dtype = None,
name: str = None,
enable_analytics: bool = False,
):
super().__init__()

@@ -42,14 +58,19 @@
)

factory_kwargs = {"device": device, "dtype": dtype}
self.device = device
self.dtype = dtype

self.in_features = in_features
self.out_features = out_features
self.name = name
self.name_analyzed = not enable_analytics

self.weight = nn.Parameter(
torch.empty((out_features, in_features), **factory_kwargs)
)
self.dqweight = nn.Parameter(
torch.empty((out_features, in_features), **factory_kwargs)
)
self.dqweight = None

if bias:
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
else:
@@ -159,18 +180,33 @@ def pre_forward(self):
self.scales = nn.Parameter(
scales.view(dtype=torch.uint8), requires_grad=False
)

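        # Once qweight/scales have been materialized, release the BF16 master
        # weight (and any dequantized copy) to free GPU memory.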
if self.weight is not None:
free_param(self.weight)
self.register_parameter("weight", None)
torch.cuda.empty_cache()
del self.weight

self.weight = None
self.dqweight = None

def forward(self, x) -> torch.Tensor:

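        # With analytics enabled, report this layer's input shape/dtype to
        # layer_analytics on the first call so it can be benchmarked offline.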
if self.name is not None and not self.name_analyzed:
self.name_analyzed = True
layer_analytics.add_layer(self, self.name, x.shape, self.in_features, self.out_features, x.device, x.dtype)


result = None

match (
self.config.forward_dtype,
self.config.backward_dtype,
self.config.store_master_weights,
self.config.pseudoquantization,
):
case (FPQuantDtype.MXFP4, FPQuantDtype.BF16, True, False):
return FPQuant4x16MasterFn.apply(
result = FPQuant4x16MasterFn.apply(
x,
self.weight,
self.bias,
@@ -179,7 +215,7 @@ def forward(self, x) -> torch.Tensor:
self.config.forward_method,
)
case (FPQuantDtype.MXFP4, FPQuantDtype.BF16, False, False):
return FPQuant4x16NoMasterFn.apply(
result = FPQuant4x16NoMasterFn.apply(
x,
self.qweight,
self.scales,
@@ -189,7 +225,7 @@ def forward(self, x) -> torch.Tensor:
self.config.forward_method,
)
case (FPQuantDtype.MXFP4, FPQuantDtype.BF16, True, True):
return PseudoQuant4x16MasterFn.apply(
result = PseudoQuant4x16MasterFn.apply(
x,
self.dqweight,
self.bias,
@@ -198,7 +234,7 @@ def forward(self, x) -> torch.Tensor:
self.config.forward_method,
)
case (FPQuantDtype.MXFP4, FPQuantDtype.BF16, False, True):
return PseudoQuant4x16NoMasterFn.apply(
result = PseudoQuant4x16NoMasterFn.apply(
x,
self.dqweight,
self.bias,
@@ -210,3 +246,6 @@ def forward(self, x) -> torch.Tensor:
raise ValueError(
f"Forward dtype: {self.config.forward_dtype}, backward dtype: {self.config.backward_dtype}, store_master_weights: {self.config.store_master_weights}, pseudoquantization: {self.config.pseudoquantization} isn't supported yet."
)

return result

6 changes: 3 additions & 3 deletions inference_lib/src/fp_quant/module/linear_fns.py
@@ -40,7 +40,7 @@ def _(x_flat, hadamard_matrix, forward_method):

@torch.library.custom_op("fp_quant::matmul_mxf4_bf16_tn_op", mutates_args=())
def matmul_mxf4_bf16_tn_op(
x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: float
x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: torch.Tensor
) -> torch.Tensor:
return matmul_mxf4_bf16_tn(
x, w, to_blocked(xs), to_blocked(ws).view(torch.float8_e8m0fnu), alpha
@@ -54,7 +54,7 @@ def _(x, w, xs, ws, alpha):

@torch.library.custom_op("fp_quant::matmul_ada_mxf4_bf16_tn_op", mutates_args=())
def matmul_ada_mxf4_bf16_tn_op(
x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: float
x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: torch.Tensor
) -> torch.Tensor:
return matmul_ada_mxf4_bf16_tn(x, w, xs, ws.view(torch.float8_e8m0fnu), alpha)

@@ -248,7 +248,7 @@ def forward(
x_flat, forward_hadamard_matrix, dtype, forward_method
)

y = forward_gemm(x_flat_q, weight_q, x_flat_scales, weight_scales, 1.0 / 9.0)
y = forward_gemm(x_flat_q, weight_q, x_flat_scales, weight_scales, torch.tensor([1.0 / 9.0], device=x.device))

y = y.unflatten(dim=0, sizes=x.shape[:-1])
if bias is not None: