@@ -13,13 +13,17 @@
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
     PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
     deprecate,
     logging,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
 
@@ -199,6 +203,7 @@ def get_unweighted_text_embeddings(
     text_input: torch.Tensor,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
+    clip_skip: Optional[int] = None,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -214,7 +219,20 @@ def get_unweighted_text_embeddings(
             # cover the head and the tail by the starting and the ending tokens
             text_input_chunk[:, 0] = text_input[0, 0]
             text_input_chunk[:, -1] = text_input[0, -1]
-            text_embedding = pipe.text_encoder(text_input_chunk)[0]
+            if clip_skip is None:
+                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
+                text_embedding = prompt_embeds[0]
+            else:
+                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
+                # Access `hidden_states` first, which contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_state` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
 
         if no_boseos_middle:
             if i == 0:
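The new `clip_skip` branch indexes the encoder's `hidden_states` tuple instead of taking `last_hidden_state` (on the returned `ModelOutput`, `[-1]` is the `hidden_states` field when `output_hidden_states=True` is passed and attentions are not requested). A minimal standalone sketch of the same indexing, using a small public CLIP checkpoint purely for illustration:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

ids = tokenizer("a photo of an astronaut", return_tensors="pt").input_ids
with torch.no_grad():
    out = text_encoder(ids, output_hidden_states=True)

# `hidden_states` holds the embedding output plus one entry per encoder layer,
# so [-1] is the final layer and [-(clip_skip + 1)] walks back from it.
clip_skip = 1
skipped = out.hidden_states[-(clip_skip + 1)]

# Intermediate hidden states have not passed through the final LayerNorm,
# so it is applied manually, matching the pipeline code above.
embeds = text_encoder.text_model.final_layer_norm(skipped)
print(embeds.shape)  # (1, seq_len, 512) for this checkpoint
```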
@@ -230,7 +248,10 @@ def get_unweighted_text_embeddings(
             text_embeddings.append(text_embedding)
         text_embeddings = torch.concat(text_embeddings, axis=1)
     else:
-        text_embeddings = pipe.text_encoder(text_input)[0]
+        if clip_skip is None:
+            clip_skip = 0
+        prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]
+        text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
     return text_embeddings
 
 
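In this single-chunk branch, defaulting `clip_skip` to `0` preserves the old behavior: inside the CLIP text model, `last_hidden_state` is computed as `final_layer_norm(hidden_states[-1])`, so rebuilding it by hand yields the same tensor. A quick sanity-check sketch (the checkpoint ID is illustrative):

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
ids = tokenizer("sanity check", return_tensors="pt").input_ids

with torch.no_grad():
    out = text_encoder(ids, output_hidden_states=True)
    # with clip_skip == 0 this is exactly the old pipe.text_encoder(...)[0]
    rebuilt = text_encoder.text_model.final_layer_norm(out.hidden_states[-1])

print(torch.allclose(out.last_hidden_state, rebuilt))  # True
```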
@@ -242,6 +263,8 @@ def get_weighted_text_embeddings(
     no_boseos_middle: Optional[bool] = False,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
+    clip_skip=None,
+    lora_scale=None,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -268,6 +291,16 @@ def get_weighted_text_embeddings(
         skip_weighting (`bool`, *optional*, defaults to `False`):
             Skip the weighting. When the parsing is skipped, it is forced True.
     """
+    # set lora scale so that monkey patched LoRA
+    # function of text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if not USE_PEFT_BACKEND:
+            adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+        else:
+            scale_lora_layers(pipe.text_encoder, lora_scale)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
     if isinstance(prompt, str):
         prompt = [prompt]
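The two branches above cover diffusers' two LoRA backends: without PEFT, `adjust_lora_scale_text_encoder` updates the legacy monkey-patched LoRA projections; with PEFT, `scale_lora_layers` multiplies each adapter's scaling by `lora_scale`. A sketch of the scale/encode/unscale contract these utilities implement (the wrapper function here is hypothetical, not part of the PR):

```python
import torch
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

def encode_with_lora_scale(text_encoder, input_ids, lora_scale=1.0):
    if USE_PEFT_BACKEND:
        # multiply every LoRA adapter's scaling by `lora_scale`
        scale_lora_layers(text_encoder, lora_scale)
    try:
        with torch.no_grad():
            embeds = text_encoder(input_ids)[0]
    finally:
        if USE_PEFT_BACKEND:
            # divide it back out so the module is left unchanged
            unscale_lora_layers(text_encoder, lora_scale)
    return embeds
```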
@@ -334,10 +367,7 @@ def get_weighted_text_embeddings(
 
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe,
-        prompt_tokens,
-        pipe.tokenizer.model_max_length,
-        no_boseos_middle=no_boseos_middle,
+        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
     if uncond_prompt is not None:
@@ -346,6 +376,7 @@ def get_weighted_text_embeddings(
             uncond_tokens,
             pipe.tokenizer.model_max_length,
             no_boseos_middle=no_boseos_middle,
+            clip_skip=clip_skip,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
 
@@ -362,6 +393,11 @@ def get_weighted_text_embeddings(
             current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
             uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
 
+    if pipe.text_encoder is not None:
+        if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
     if uncond_prompt is not None:
         return text_embeddings, uncond_embeddings
     return text_embeddings, None
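`unscale_lora_layers` divides the scaling back out so the text encoder is left in its original state after encoding; only the PEFT path needs this, since the legacy path simply overwrites the scale on the next call. Roughly what the utility does under the hood (a simplified sketch, not the verbatim `diffusers.utils.peft_utils` source):

```python
from peft.tuners.lora import LoraLayer

def unscale_lora_layers_sketch(model, lora_scale=None):
    for module in model.modules():
        if isinstance(module, LoraLayer):
            if lora_scale is not None and lora_scale != 0:
                module.unscale_layer(lora_scale)  # scaling /= lora_scale
            else:
                module.unscale_layer(None)  # reset scaling to lora_alpha / r
```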
@@ -549,6 +585,8 @@ def _encode_prompt(
         max_embeddings_multiples=3,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        clip_skip: Optional[int] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -597,6 +635,8 @@ def _encode_prompt(
                 prompt=prompt,
                 uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                 max_embeddings_multiples=max_embeddings_multiples,
+                clip_skip=clip_skip,
+                lora_scale=lora_scale,
             )
         if prompt_embeds is None:
             prompt_embeds = prompt_embeds1
@@ -790,6 +830,7 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip: Optional[int] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -865,6 +906,9 @@ def __call__(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
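For reference, this is how the new argument looks from the pipeline's public API (a sketch; the checkpoint ID is illustrative):

```python
import torch
from diffusers import DiffusionPipeline

# Load this community pipeline; weighted-prompt syntax like "(word:1.2)"
# comes from the long-prompt-weighting parser above.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    "a (detailed:1.2) portrait of a knight",
    num_inference_steps=30,
    clip_skip=1,  # 1 = take the pre-final CLIP layer, per the docstring above
).images[0]
```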
@@ -903,6 +947,7 @@ def __call__(
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
 
         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
@@ -914,6 +959,8 @@ def __call__(
             max_embeddings_multiples,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )
         dtype = prompt_embeds.dtype
 
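Note that `lora_scale` is not a new user-facing argument: it is read out of the same `cross_attention_kwargs["scale"]` entry that already scales the UNet's LoRA layers, so one knob now scales both models. Reusing `pipe` from the previous sketch:

```python
# 0.0 effectively disables the LoRA, 1.0 applies it at full strength; with
# this change the value also reaches the text encoder via _encode_prompt.
image = pipe(
    "a watercolor landscape",
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```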
@@ -1044,6 +1091,7 @@ def text2img(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip=None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -1101,6 +1149,9 @@ def text2img(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -1135,6 +1186,7 @@ def text2img(
             return_dict=return_dict,
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
+            clip_skip=clip_skip,
             callback_steps=callback_steps,
             cross_attention_kwargs=cross_attention_kwargs,
        )
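Everything combined through the `text2img` convenience wrapper, again reusing `pipe` from the earlier sketch and assuming LoRA weights have been loaded (the path is a placeholder):

```python
pipe.load_lora_weights("path/to/lora")  # placeholder path

image = pipe.text2img(
    "a (masterpiece:1.3) photo of a red fox in a snowy forest, (sharp focus:1.1)",
    negative_prompt="blurry, low quality",
    max_embeddings_multiples=3,  # allow prompts up to 3x the 77-token window
    clip_skip=1,
    cross_attention_kwargs={"scale": 0.8},
).images[0]
```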