
Commit 4ca3d84

Support for Chroma - Flux1 Schnell distilled with CFG (Comfy-Org#7355)

* Upload files for Chroma Implementation
* Remove trailing whitespace
* trim more trailing whitespace..oops
* remove unused imports
* Add supported_inference_dtypes
* Set min_length to 0 and remove attention_mask=True
* Set min_length to 1
* get_modulations added from blepping and minor changes
* Add lora conversion if statement in lora.py
* Update supported_models.py
* update model_base.py
* add upstream commits
* set modelType.FLOW, will cause beta scheduler to work properly
* Adjust memory usage factor and remove unnecessary code
* fix mistake
* reduce code duplication
* remove unused imports
* refactor for upstream sync
* sync chroma-support with upstream via syncbranch patch
* Update sd.py
* Add Chroma as option for the OptimalStepsScheduler node
1 parent 39c27a3 commit 4ca3d84

File tree

11 files changed: +667 additions, -4 deletions

comfy/ldm/chroma/layers.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
import torch
from torch import Tensor, nn

from .math import attention
from comfy.ldm.flux.layers import (
    MLPEmbedder,
    RMSNorm,
    QKNorm,
    SelfAttention,
    ModulationOut,
)


class ChromaModulationOut(ModulationOut):
    @classmethod
    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
        return cls(
            shift=tensor[:, offset : offset + 1, :],
            scale=tensor[:, offset + 1 : offset + 2, :],
            gate=tensor[:, offset + 2 : offset + 3, :],
        )


class Approximator(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=5, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
        self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def forward(self, x: Tensor) -> Tensor:
        x = self.in_proj(x)

        for layer, norms in zip(self.layers, self.norms):
            x = x + layer(norms(x))

        x = self.out_proj(x)

        return x


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        attn = attention(torch.cat((txt_q, img_q), dim=2),
                         torch.cat((txt_k, img_k), dim=2),
                         torch.cat((txt_v, img_v), dim=2),
                         pe=pe, mask=attn_mask)

        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)

        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")

    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
        mod = vec
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x += mod.gate * output
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = vec
        shift = shift.squeeze(1)
        scale = scale.squeeze(1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
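
The blocks in this file consume pre-computed modulation triplets (shift/scale/gate) passed in via vec, which ChromaModulationOut.from_offset slices out of the packed tensor produced by the Approximator. A minimal, standalone sketch of that slicing-and-modulation pattern follows; the Mod NamedTuple is only a stand-in for comfy.ldm.flux.layers.ModulationOut, and all shapes and dummy tensors are illustrative assumptions, not part of the commit.

from typing import NamedTuple
import torch

class Mod(NamedTuple):                      # stand-in for comfy.ldm.flux.layers.ModulationOut
    shift: torch.Tensor
    scale: torch.Tensor
    gate: torch.Tensor

def from_offset(packed: torch.Tensor, offset: int = 0) -> Mod:
    # packed: (batch, num_modulation_vectors, hidden); same slicing as ChromaModulationOut.from_offset
    return Mod(
        shift=packed[:, offset : offset + 1, :],
        scale=packed[:, offset + 1 : offset + 2, :],
        gate=packed[:, offset + 2 : offset + 3, :],
    )

batch, hidden = 2, 8
packed = torch.randn(batch, 6, hidden)      # two shift/scale/gate triplets packed back to back
mod1, mod2 = from_offset(packed, 0), from_offset(packed, 3)

x = torch.randn(batch, 16, hidden)          # toy token stream
norm = torch.nn.LayerNorm(hidden, elementwise_affine=False, eps=1e-6)
x_mod = (1 + mod1.scale) * norm(x) + mod1.shift                 # pre-attention modulation, as in the blocks above
x = x + mod1.gate * x_mod                                       # gated residual update (attention omitted here)
x = x + mod2.gate * ((1 + mod2.scale) * norm(x) + mod2.shift)   # second, MLP-style modulated update (MLP omitted)
print(x.shape)                              # torch.Size([2, 16, 8])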

comfy/ldm/chroma/math.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import torch
from einops import rearrange
from torch import Tensor

from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
    q_shape = q.shape
    k_shape = k.shape

    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
        device = torch.device("cpu")
    else:
        device = pos.device

    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
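
As a quick sanity check on the rotary math above, the following standalone sketch re-implements rope and apply_rope with plain torch (float32 only, no comfy imports or device special-casing) and applies them to toy q/k tensors. The toy shapes, the theta value, and the added singleton heads axis are illustrative assumptions about how pe is laid out, not part of the commit.

import torch
from einops import rearrange

def rope(pos, dim, theta):
    # per-position 2x2 rotation matrices, mirroring rope() above (float32 for brevity)
    assert dim % 2 == 0
    scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2)
    omega = 1.0 / (theta ** scale)
    out = torch.einsum("...n,d->...nd", pos.float(), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    return rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)

def apply_rope(xq, xk, freqs_cis):
    # same rotation as apply_rope() above: each (even, odd) channel pair is rotated by its 2x2 matrix
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

batch, heads, seq, head_dim = 1, 2, 5, 8          # toy sizes
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
pos = torch.arange(seq)[None, :]                  # (batch, seq) position ids
pe = rope(pos, head_dim, theta=10000)             # (batch, seq, head_dim // 2, 2, 2)
pe = pe[:, None]                                  # singleton heads axis so it broadcasts over q and k
q_rot, k_rot = apply_rope(q, k, pe)
print(q_rot.shape, k_rot.shape)                   # both torch.Size([1, 2, 5, 8])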

0 commit comments

Comments
 (0)