
Commit b3c0e4d

Make loras work on nvfp4 models. (Comfy-Org#11837)
The initial application is a bit slow but will probably be sped up in the future.
1 parent ecaeeb9 commit b3c0e4d

4 files changed: +150 -4 lines

comfy/float.py

Lines changed: 113 additions & 0 deletions
@@ -65,3 +65,116 @@ def stochastic_rounding(value, dtype, seed=0):
         return output
 
     return value.to(dtype=dtype)
+
+
+# TODO: improve this?
+def stochastic_float_to_fp4_e2m1(x, generator):
+    sign = torch.signbit(x).to(torch.uint8)
+    x_abs = x.abs()
+
+    exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3)
+    x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
+
+    x_abs = x.abs()
+    exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3)
+
+    mantissa = torch.where(
+        exp > 0,
+        (x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0,
+        (x_abs * 2.0)
+    ).round().to(torch.uint8)
+
+    fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa
+
+    fp4_flat = fp4.view(-1)
+    packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
+    return packed.reshape(list(x.shape)[:-1] + [-1])
+
+
+def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
+    """
+    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
+    See:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        input_matrix: Input tensor of shape (H, W)
+    Returns:
+        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
+    """
+
+    def ceil_div(a, b):
+        return (a + b - 1) // b
+
+    rows, cols = input_matrix.shape
+    n_row_blocks = ceil_div(rows, 128)
+    n_col_blocks = ceil_div(cols, 4)
+
+    # Calculate the padded shape
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    padded = input_matrix
+    if (rows, cols) != (padded_rows, padded_cols):
+        padded = torch.zeros(
+            (padded_rows, padded_cols),
+            device=input_matrix.device,
+            dtype=input_matrix.dtype,
+        )
+        padded[:rows, :cols] = input_matrix
+
+    # Rearrange the blocks
+    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
+    if flatten:
+        return rearranged.flatten()
+
+    return rearranged.reshape(padded_rows, padded_cols)
+
+
+def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
+    F4_E2M1_MAX = 6.0
+    F8_E4M3_MAX = 448.0
+
+    def roundup(x: int, multiple: int) -> int:
+        """Round up x to the nearest multiple."""
+        return ((x + multiple - 1) // multiple) * multiple
+
+    orig_shape = x.shape
+
+    # Handle padding
+    if pad_16x:
+        rows, cols = x.shape
+        padded_rows = roundup(rows, 16)
+        padded_cols = roundup(cols, 16)
+        if padded_rows != rows or padded_cols != cols:
+            x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
+            # Note: We update orig_shape because the output tensor logic below assumes x.shape matches
+            # what we want to produce. If we pad here, we want the padded output.
+            orig_shape = x.shape
+
+    block_size = 16
+
+    x = x.reshape(orig_shape[0], -1, block_size)
+    max_abs = torch.amax(torch.abs(x), dim=-1)
+    block_scale = max_abs / F4_E2M1_MAX
+    scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
+    scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
+    total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
+
+    # Handle zero blocks (from padding): avoid 0/0 NaN
+    zero_scale_mask = (total_scale == 0)
+    total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
+
+    x = x / total_scale_safe.unsqueeze(-1)
+
+    generator = torch.Generator(device=x.device)
+    generator.manual_seed(seed)
+
+    x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
+
+    x = x.view(orig_shape)
+    data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
+
+    blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
+    return data_lp, blocked_scales
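
For orientation (an editor's sketch, not part of the commit): each FP4 E2M1 code produced above is sign (1 bit) | exponent (2 bits) | mantissa (1 bit), packed two codes per uint8 with the even-indexed element in the high nibble. The representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6, which is why F4_E2M1_MAX above is 6.0. The minimal decoder below inverts that packing and is handy for sanity-checking the encoder; the helper name unpack_fp4_e2m1 and the lookup-table approach are ours, not from the codebase.

import torch

# The eight non-negative E2M1 magnitudes, indexed by (exp << 1) | mantissa.
# exp == 0 is the subnormal range (value = 0.5 * mantissa); otherwise
# value = (1 + 0.5 * mantissa) * 2 ** (exp - 1).
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def unpack_fp4_e2m1(packed):
    # Undo the nibble packing from stochastic_float_to_fp4_e2m1:
    # even elements sit in the high nibble, odd elements in the low nibble.
    hi = (packed >> 4) & 0x0F
    lo = packed & 0x0F
    codes = torch.stack((hi, lo), dim=-1).flatten(-2)  # restore element order
    sign = torch.where((codes & 0x08) != 0, -1.0, 1.0)
    return sign * E2M1_VALUES[(codes & 0x07).long()]

# 0x7B packs the codes 0b0111 (+6.0) and 0b1011 (-1.5):
print(unpack_fp4_e2m1(torch.tensor([0x7B], dtype=torch.uint8)))  # tensor([ 6.0000, -1.5000])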

comfy/ops.py

Lines changed: 1 addition & 1 deletion
@@ -699,7 +699,7 @@ def convert_weight(self, weight, inplace=False, **kwargs):
     def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
         if getattr(self, 'layout_type', None) is not None:
             # dtype is now implicit in the layout class
-            weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
+            weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
         else:
             weight = weight.to(self.weight.dtype)
         if return_weight:

comfy/quant_ops.py

Lines changed: 35 additions & 2 deletions
@@ -7,7 +7,7 @@
     QuantizedTensor,
     QuantizedLayout,
     TensorCoreFP8Layout as _CKFp8Layout,
-    TensorCoreNVFP4Layout,  # Direct import, no wrapper needed
+    TensorCoreNVFP4Layout as _CKNvfp4Layout,
     register_layout_op,
     register_layout_class,
     get_layout_class,
@@ -34,7 +34,7 @@ class QuantizedTensor:
     class _CKFp8Layout:
         pass
 
-    class TensorCoreNVFP4Layout:
+    class _CKNvfp4Layout:
         pass
 
     def register_layout_class(name, cls):
@@ -84,6 +84,39 @@ def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
         return qdata, params
 
 
+class TensorCoreNVFP4Layout(_CKNvfp4Layout):
+    @classmethod
+    def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
+        if tensor.dim() != 2:
+            raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
+
+        orig_dtype = tensor.dtype
+        orig_shape = tuple(tensor.shape)
+
+        if scale is None or (isinstance(scale, str) and scale == "recalculate"):
+            scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)
+
+        if not isinstance(scale, torch.Tensor):
+            scale = torch.tensor(scale)
+        scale = scale.to(device=tensor.device, dtype=torch.float32)
+
+        padded_shape = cls.get_padded_shape(orig_shape)
+        needs_padding = padded_shape != orig_shape
+
+        if stochastic_rounding > 0:
+            qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
+        else:
+            qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
+
+        params = cls.Params(
+            scale=scale,
+            orig_dtype=orig_dtype,
+            orig_shape=orig_shape,
+            block_scale=block_scale,
+        )
+        return qdata, params
+
+
 class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
     FP8_DTYPE = torch.float8_e4m3fn
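
The quantize() wrapper above is where the two-level NVFP4 scaling is set up: the per-tensor scale is picked so that even the largest per-block scale still fits in float8_e4m3fn, and the per-block scales are stored relative to it. A minimal sketch of that arithmetic on a toy 2D weight (our reconstruction, mirroring stochastic_round_quantize_nvfp4 in comfy/float.py; not code from the commit):

import torch

F4_E2M1_MAX = 6.0    # largest FP4 E2M1 magnitude
F8_E4M3_MAX = 448.0  # largest float8_e4m3fn magnitude

w = torch.randn(128, 64)  # toy 2D weight

# Per-tensor scale: the "recalculate" branch above, amax / (448 * 6).
per_tensor_scale = torch.amax(w.abs()) / (F8_E4M3_MAX * F4_E2M1_MAX)

# Per-block scales over 16-element blocks along the last dim, kept in fp8
# relative to the per-tensor scale.
blocks = w.reshape(w.shape[0], -1, 16)
block_scale = torch.amax(blocks.abs(), dim=-1) / F4_E2M1_MAX
block_scale_fp8 = (block_scale / per_tensor_scale).clamp(max=F8_E4M3_MAX).to(torch.float8_e4m3fn)

# Effective dequantization factor per 16-element block; decoded FP4 codes
# are multiplied by this to recover the weight.
total_scale = per_tensor_scale * block_scale_fp8.to(torch.float32)
assert total_scale.shape == (128, 4)

With stochastic_rounding=seed > 0 (how set_weight in comfy/ops.py invokes it when re-quantizing a LoRA-patched weight), quantization takes the new Python path in comfy/float.py, which swizzles the fp8 block scales through to_blocked() into the cuBLAS block-scaling layout; otherwise the pre-existing ck.quantize_nvfp4 kernel is used.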

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ psutil
 alembic
 SQLAlchemy
 av>=14.2.0
-comfy-kitchen>=0.2.5
+comfy-kitchen>=0.2.6
 
 #non essential dependencies:
 kornia>=0.7.1
