11import torch
22from comfy .ldm .modules .attention import optimized_attention_for_device
33import comfy .ops
4+ import math
45
56def clip_preprocess (image , size = 224 , mean = [0.48145466 , 0.4578275 , 0.40821073 ], std = [0.26862954 , 0.26130258 , 0.27577711 ], crop = True ):
67 image = image [:, :, :, :3 ] if image .shape [3 ] > 3 else image
@@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
2122 image = torch .clip ((255. * image ), 0 , 255 ).round () / 255.0
2223 return (image - mean .view ([3 ,1 ,1 ])) / std .view ([3 ,1 ,1 ])
2324
def siglip2_flex_calc_resolution (oh , ow , patch_size , max_num_patches , eps = 1e-5 ):
    """Choose an output (height, width) for a flexible-resolution input.

    Both returned dimensions are multiples of ``patch_size``. The original
    aspect ratio of ``(oh, ow)`` is approximately preserved while the total
    number of patches stays at or below ``max_num_patches``.

    A bisection on the uniform scale factor narrows the interval until it is
    smaller than ``eps``; the lower bound is always kept on the feasible
    (patch budget respected) side, so it is the one used for the result.
    """
    def snap (size , factor ):
        # Scale, then round up to the next multiple of patch_size,
        # never dropping below a single patch.
        return max (patch_size , int (math .ceil (size * factor / patch_size ) * patch_size ))

    def patch_count (factor ):
        return (snap (oh , factor ) // patch_size ) * (snap (ow , factor ) // patch_size )

    low , high = eps / 10 , 100.0
    while high - low >= eps :
        probe = (low + high ) / 2
        if patch_count (probe ) <= max_num_patches :
            low = probe  # still within budget -> try larger scales
        else :
            high = probe  # over budget -> shrink

    return snap (oh , low ), snap (ow , low )
41+
def siglip2_preprocess (image , size , patch_size , num_patches , mean = [0.5 , 0.5 , 0.5 ], std = [0.5 , 0.5 , 0.5 ], crop = True ):
    """Preprocess an NHWC float image batch for a SigLIP2 vision encoder.

    When ``size`` is positive the standard fixed-resolution CLIP pipeline is
    used. Otherwise the flexible-resolution path resizes the image to a
    patch-aligned resolution chosen to fit within ``num_patches`` patches of
    ``patch_size``, then normalizes with ``mean``/``std``. Returns an NCHW
    tensor.
    """
    if size > 0 :
        # Fixed-resolution variant: defer entirely to the CLIP preprocessor.
        return clip_preprocess (image , size = size , mean = mean , std = std , crop = crop )

    # Keep at most the first 3 channels (drop alpha, if present).
    if image .shape [3 ] > 3 :
        image = image [:, :, :, :3 ]
    image = image .movedim (- 1 , 1 )  # NHWC -> NCHW

    batch , channels , height , width = image .shape
    height , width = siglip2_flex_calc_resolution (height , width , patch_size , num_patches )

    image = torch .nn .functional .interpolate (image , size = (height , width ), mode = "bilinear" , antialias = True )
    # Quantize to 8-bit levels, mimicking a real image decode round-trip.
    image = torch .clip (image * 255. , 0 , 255 ).round () / 255.0

    mean_t = torch .tensor (mean , device = image .device , dtype = image .dtype ).view ([3 , 1 , 1 ])
    std_t = torch .tensor (std , device = image .device , dtype = image .dtype ).view ([3 , 1 , 1 ])
    return (image - mean_t ) / std_t
57+
2458class CLIPAttention (torch .nn .Module ):
2559 def __init__ (self , embed_dim , heads , dtype , device , operations ):
2660 super ().__init__ ()
@@ -175,6 +209,27 @@ def forward(self, *args, **kwargs):
175209 out = self .text_projection (x [2 ])
176210 return (x [0 ], x [1 ], out , x [2 ])
177211
def siglip2_pos_embed (embed_weight , embeds , orig_shape ):
    """Add position embeddings, resized to the input's patch grid.

    embed_weight: (num_positions, dim) learned table; treated as describing a
        square grid of side round(sqrt(num_positions)) -- assumes the table
        was trained on a square layout, TODO confirm against checkpoints.
    embeds: (..., grid_h * grid_w, dim) patch embeddings.
    orig_shape: (grid_h, grid_w) target grid for the current input.
    """
    side = round (embed_weight .shape [0 ] ** 0.5 )
    # (num_positions, dim) -> (1, dim, side, side) so it can be interpolated
    # like an image; cast to the dtype/device of the patch embeddings first.
    grid = comfy .ops .cast_to_input (embed_weight , embeds ).movedim (1 , 0 ).reshape (1 , - 1 , side , side )
    grid = torch .nn .functional .interpolate (grid , size = orig_shape , mode = "bilinear" , align_corners = False , antialias = True )
    # Flatten back to (grid_h * grid_w, dim) and add broadcast-wise.
    pos = grid .reshape (grid .shape [1 ], - 1 ).movedim (0 , 1 )
    return embeds + pos
218+
class Siglip2Embeddings (torch .nn .Module ):
    """Patchify-and-embed module for SigLIP2 flexible-resolution inputs.

    Patches are extracted by reshaping (no convolution) and projected with a
    Linear layer; position embeddings are interpolated to the input's patch
    grid at forward time via ``siglip2_pos_embed``.
    """
    def __init__ (self , embed_dim , num_channels = 3 , patch_size = 14 , image_size = 224 , model_type = "" , num_patches = None , dtype = None , device = None , operations = None ):
        super ().__init__ ()
        # Each flattened patch carries num_channels * patch_size**2 values.
        self .patch_embedding = operations .Linear (num_channels * patch_size * patch_size , embed_dim , dtype = dtype , device = device )
        self .position_embedding = operations .Embedding (num_patches , embed_dim , dtype = dtype , device = device )
        self .patch_size = patch_size

    def forward (self , pixel_values ):
        batch , channels , height , width = pixel_values .shape
        p = self .patch_size
        grid_h , grid_w = height // p , width // p
        # NCHW -> NHWC, then tile the spatial dims as (grid_h, p, grid_w, p).
        patches = pixel_values .movedim (1 , - 1 ).reshape (batch , grid_h , p , grid_w , p , channels )
        # Bring the two patch-index axes together, then flatten each patch.
        patches = patches .permute (0 , 1 , 3 , 2 , 4 , 5 ).reshape (batch , grid_h * grid_w , - 1 )
        tokens = self .patch_embedding (patches )
        return siglip2_pos_embed (self .position_embedding .weight , tokens , (grid_h , grid_w ))
178233
179234class CLIPVisionEmbeddings (torch .nn .Module ):
180235 def __init__ (self , embed_dim , num_channels = 3 , patch_size = 14 , image_size = 224 , model_type = "" , dtype = None , device = None , operations = None ):
@@ -218,8 +273,11 @@ def __init__(self, config_dict, dtype, device, operations):
218273 intermediate_activation = config_dict ["hidden_act" ]
219274 model_type = config_dict ["model_type" ]
220275
221- self .embeddings = CLIPVisionEmbeddings (embed_dim , config_dict ["num_channels" ], config_dict ["patch_size" ], config_dict ["image_size" ], model_type = model_type , dtype = dtype , device = device , operations = operations )
222- if model_type == "siglip_vision_model" :
276+ if model_type in ["siglip2_vision_model" ]:
277+ self .embeddings = Siglip2Embeddings (embed_dim , config_dict ["num_channels" ], config_dict ["patch_size" ], config_dict ["image_size" ], model_type = model_type , num_patches = config_dict .get ("num_patches" , None ), dtype = dtype , device = device , operations = operations )
278+ else :
279+ self .embeddings = CLIPVisionEmbeddings (embed_dim , config_dict ["num_channels" ], config_dict ["patch_size" ], config_dict ["image_size" ], model_type = model_type , dtype = dtype , device = device , operations = operations )
280+ if model_type in ["siglip_vision_model" , "siglip2_vision_model" ]:
223281 self .pre_layrnorm = lambda a : a
224282 self .output_layernorm = True
225283 else :
0 commit comments