Skip to content

Commit c881a1d

Browse files
Support the siglip 2 naflex model as a clip vision model. (Comfy-Org#11831)
Not useful yet.
1 parent a3b5d49 commit c881a1d

File tree

3 files changed

+91
-10
lines changed

3 files changed

+91
-10
lines changed

comfy/clip_model.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from comfy.ldm.modules.attention import optimized_attention_for_device
33
import comfy.ops
4+
import math
45

56
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
67
image = image[:, :, :, :3] if image.shape[3] > 3 else image
@@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
2122
image = torch.clip((255. * image), 0, 255).round() / 255.0
2223
return (image - mean.view([3,1,1])) / std.view([3,1,1])
2324

25+
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
    """Pick the largest patch-aligned (height, width) within a patch budget.

    Bisects over a resize scale applied to the original dimensions
    (oh, ow). Each candidate dimension is rounded up to a whole number of
    patches (never below one patch), and a scale is feasible when the
    resulting patch grid holds at most ``max_num_patches`` patches.
    Returns the snapped (height, width) for the best feasible scale.
    """
    def snap(dim, factor):
        # Round the scaled dimension up to a multiple of patch_size,
        # flooring at a single patch.
        units = math.ceil(dim * factor / patch_size)
        return max(patch_size, int(units * patch_size))

    low = eps / 10
    high = 100.0
    # Invariant: `low` always yields a patch count within the budget.
    while high - low >= eps:
        cand = (low + high) / 2
        n_patches = (snap(oh, cand) // patch_size) * (snap(ow, cand) // patch_size)
        if n_patches > max_num_patches:
            high = cand
        else:
            low = cand

    return snap(oh, low), snap(ow, low)
41+
42+
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
    """Preprocess an NHWC image batch for a SigLIP2 vision model.

    When `size` is positive the model uses a fixed resolution, so this
    defers entirely to the standard CLIP pipeline. Otherwise (NaFlex) the
    image is resized to a patch-aligned resolution chosen to fit within
    `num_patches`, quantized back to 8-bit levels, and normalized.
    Note: `crop` only applies to the fixed-resolution path.
    """
    if size > 0:
        return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)

    # Keep at most the RGB channels; drop any alpha.
    if image.shape[3] > 3:
        image = image[:, :, :, :3]

    mean_t = torch.tensor(mean, device=image.device, dtype=image.dtype)
    std_t = torch.tensor(std, device=image.device, dtype=image.dtype)

    # NHWC -> NCHW for interpolation.
    image = image.movedim(-1, 1)
    batch, channels, height, width = image.shape

    # Choose a patch-aligned target resolution that fits the patch budget.
    new_h, new_w = siglip2_flex_calc_resolution(height, width, patch_size, num_patches)
    image = torch.nn.functional.interpolate(image, size=(new_h, new_w), mode="bilinear", antialias=True)

    # Quantize to 8-bit levels (mirrors clip_preprocess), then normalize.
    image = torch.clip(image * 255., 0, 255).round() / 255.0
    return (image - mean_t.view([3, 1, 1])) / std_t.view([3, 1, 1])
57+
2458
class CLIPAttention(torch.nn.Module):
2559
def __init__(self, embed_dim, heads, dtype, device, operations):
2660
super().__init__()
@@ -175,6 +209,27 @@ def forward(self, *args, **kwargs):
175209
out = self.text_projection(x[2])
176210
return (x[0], x[1], out, x[2])
177211

212+
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
    """Add a position embedding resampled to the token grid of `embeds`.

    `embed_weight` is a (num_positions, dim) learned table assumed to be
    laid out as a square grid; it is bilinearly resized to `orig_shape`
    (grid_h, grid_w) and added to the patch embeddings.
    """
    grid_side = round(embed_weight.shape[0] ** 0.5)
    # Match dtype/device of the embeddings, then view the table as a
    # (1, C, side, side) "image" so it can be interpolated.
    pos = comfy.ops.cast_to_input(embed_weight, embeds)
    pos = pos.movedim(1, 0).reshape(1, -1, grid_side, grid_side)
    pos = torch.nn.functional.interpolate(pos, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
    # Back to (grid_h * grid_w, C) so it broadcasts over the batch dim.
    pos = pos.reshape(-1, pos.shape[-2] * pos.shape[-1]).movedim(0, 1)
    return embeds + pos
218+
219+
class Siglip2Embeddings(torch.nn.Module):
    """Patch + position embeddings for SigLIP2 NaFlex vision towers.

    Unlike the conv-based CLIP embedding, patches are extracted by
    reshaping the pixel grid and projected with a Linear layer; the
    learned position table is bilinearly resized to the actual patch
    grid at forward time (see siglip2_pos_embed).
    """

    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
        super().__init__()
        # Linear projection over a flattened (patch_size x patch_size x C) patch.
        self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
        self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
        self.patch_size = patch_size

    def forward(self, pixel_values):
        # pixel_values: (batch, channels, height, width); height and width
        # are expected to be multiples of patch_size (the preprocessor
        # snaps resolutions to the patch grid).
        p = self.patch_size
        b, c, h, w = pixel_values.shape
        grid_h, grid_w = h // p, w // p
        # NCHW -> NHWC, then carve into (grid_h, p, grid_w, p, c) tiles.
        patches = pixel_values.movedim(1, -1).reshape(b, grid_h, p, grid_w, p, c)
        # Bring the two grid axes together, then flatten each patch.
        patches = patches.permute(0, 1, 3, 2, 4, 5).reshape(b, grid_h * grid_w, -1)
        tokens = self.patch_embedding(patches)
        return siglip2_pos_embed(self.position_embedding.weight, tokens, (grid_h, grid_w))
178233

179234
class CLIPVisionEmbeddings(torch.nn.Module):
180235
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
@@ -218,8 +273,11 @@ def __init__(self, config_dict, dtype, device, operations):
218273
intermediate_activation = config_dict["hidden_act"]
219274
model_type = config_dict["model_type"]
220275

221-
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
222-
if model_type == "siglip_vision_model":
276+
if model_type in ["siglip2_vision_model"]:
277+
self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
278+
else:
279+
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
280+
if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
223281
self.pre_layrnorm = lambda a: a
224282
self.output_layernorm = True
225283
else:

comfy/clip_vision.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __setitem__(self, key, item):
2121
IMAGE_ENCODERS = {
2222
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
2323
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
24+
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
2425
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
2526
}
2627

@@ -32,9 +33,10 @@ def __init__(self, json_config):
3233
self.image_size = config.get("image_size", 224)
3334
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
3435
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
35-
model_type = config.get("model_type", "clip_vision_model")
36-
model_class = IMAGE_ENCODERS.get(model_type)
37-
if model_type == "siglip_vision_model":
36+
self.model_type = config.get("model_type", "clip_vision_model")
37+
self.config = config.copy()
38+
model_class = IMAGE_ENCODERS.get(self.model_type)
39+
if self.model_type == "siglip_vision_model":
3840
self.return_all_hidden_states = True
3941
else:
4042
self.return_all_hidden_states = False
@@ -55,7 +57,10 @@ def get_sd(self):
5557

5658
def encode_image(self, image, crop=True):
5759
comfy.model_management.load_model_gpu(self.patcher)
58-
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
60+
if self.model_type == "siglip2_vision_model":
61+
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
62+
else:
63+
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
5964
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
6065

6166
outputs = Output()
@@ -107,10 +112,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
107112
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
108113
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
109114
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
110-
if embed_shape == 729:
111-
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
112-
elif embed_shape == 1024:
113-
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
115+
patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
116+
if len(patch_embedding_shape) == 2:
117+
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
118+
else:
119+
if embed_shape == 729:
120+
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
121+
elif embed_shape == 1024:
122+
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
114123
elif embed_shape == 577:
115124
if "multi_modal_projector.linear_1.bias" in sd:
116125
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"num_channels": 3,
3+
"hidden_act": "gelu_pytorch_tanh",
4+
"hidden_size": 1152,
5+
"image_size": -1,
6+
"intermediate_size": 4304,
7+
"model_type": "siglip2_vision_model",
8+
"num_attention_heads": 16,
9+
"num_hidden_layers": 27,
10+
"patch_size": 16,
11+
"num_patches": 256,
12+
"image_mean": [0.5, 0.5, 0.5],
13+
"image_std": [0.5, 0.5, 0.5]
14+
}

0 commit comments

Comments
 (0)