[modular] Refactor pipeline functions #10726

Closed
@@ -221,15 +221,24 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
ip_adapter_image=data.ip_adapter_image,
ip_adapter_image_embeds=None,
device=data.device,
num_images_per_prompt=1,
do_classifier_free_guidance=data.do_classifier_free_guidance,
)

if data.do_classifier_free_guidance:
data.negative_ip_adapter_embeds = []
for i, image_embeds in enumerate(data.ip_adapter_embeds):
negative_image_embeds, image_embeds = image_embeds.chunk(2)
data.negative_ip_adapter_embeds.append(negative_image_embeds)
data.ip_adapter_embeds[i] = image_embeds
output_hidden_states = [
    not isinstance(image_proj_layer, ImageProjection)
    for image_proj_layer in pipeline.unet.encoder_hid_proj.image_projection_layers
]
negative_ip_adapter_embeds = []
for (idx, output_hidden_state), ip_adapter_embeds in zip(enumerate(output_hidden_states), data.ip_adapter_embeds):
if not output_hidden_state:
negative_ip_adapter_embed = torch.zeros_like(ip_adapter_embeds)
else:
ip_adapter_image = data.ip_adapter_image[idx] if isinstance(data.ip_adapter_image, list) else data.ip_adapter_image
ip_adapter_image = pipeline.feature_extractor(ip_adapter_image, return_tensors="pt").pixel_values
negative_ip_adapter_embed = pipeline.prepare_ip_adapter_image_embeds(
ip_adapter_image=ip_adapter_image,
ip_adapter_image_embeds=None,
device=data.device,
)
negative_ip_adapter_embeds.append(negative_ip_adapter_embed)
data.negative_ip_adapter_embeds = negative_ip_adapter_embeds

self.add_block_state(state, data)
return pipeline, state
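For contrast, a toy torch-only sketch of the two layouts handled above (shapes made up): the removed branch recovered negatives by splitting the stacked [negative, positive] embeds, while the new branch builds them explicitly.

```python
import torch

# Old layout: negative and positive stacked in one tensor, split with chunk(2).
combined = torch.randn(2, 4, 1024)
negative_old, positive_old = combined.chunk(2)

# New layout: prepare_ip_adapter_image_embeds returns positives only; for a
# plain ImageProjection layer the negative is simply an all-zeros tensor.
positive_new = torch.randn(1, 4, 1024)
negative_new = torch.zeros_like(positive_new)
```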
@@ -340,7 +349,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
data.prompt,
data.prompt_2,
data.device,
1,
data.do_classifier_free_guidance,
data.negative_prompt,
data.negative_prompt_2,
@@ -350,6 +358,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
negative_pooled_prompt_embeds=None,
lora_scale=data.text_encoder_lora_scale,
clip_skip=data.clip_skip,
force_zeros_for_empty_prompt=self.configs.get('force_zeros_for_empty_prompt', False),
)
# Add outputs
self.add_block_state(state, data)
@@ -3197,8 +3206,7 @@ def _get_add_time_ids_img2img(

return add_time_ids, add_neg_time_ids

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
def encode_image(self, image, device, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype

if not isinstance(image, torch.Tensor):
@@ -3207,20 +3215,10 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
return image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)

return image_embeds, uncond_image_embeds
return image_embeds
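With the unconditional branch gone, callers that need classifier-free guidance build the negative themselves. A minimal sketch, assuming `pipe` exposes the refactored method and `image` is a PIL image (both names hypothetical); the zeros mirror what the removed code returned on the non-hidden-states path:

```python
import torch

image_embeds = pipe.encode_image(image, device="cuda", output_hidden_states=False)
negative_image_embeds = torch.zeros_like(image_embeds)
# Stack for CFG the usual way: unconditional first, conditional second.
embeds_for_cfg = torch.cat([negative_image_embeds, image_embeds], dim=0)
```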

# Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
# 1. return image without applying any guidance
@@ -3254,13 +3252,11 @@ def prepare_control_image(

return image

# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
Collaborator: I was thinking that we could have a new _encode_prompt method (or maybe better, a public method for this, so use a different name) and refactor the current encode_prompt to use that new method. What do you think?

Contributor Author: Yeah, that works, and I agree a public method is better for this. How about encode_single_prompt?

Collaborator (@yiyixuxu, Feb 6, 2025): Sounds good! We can discuss with the team and change it later if needed (it will live under this PR for now, but I think we should change it for the regular pipelines too).

Contributor Author: I've added encode_single_prompt and encode_prompt.

def encode_prompt(
self,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
@@ -3270,6 +3266,48 @@ def encode_prompt(
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
force_zeros_for_empty_prompt: bool = False,
):
(
prompt_embeds,
pooled_prompt_embeds,
) = self.encode_single_prompt(
prompt,
prompt_2,
device,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
zero_out_negative_prompt = negative_prompt is None and force_zeros_for_empty_prompt
if do_classifier_free_guidance and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and not zero_out_negative_prompt:
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_single_prompt(
negative_prompt,
negative_prompt_2,
device,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
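A quick usage sketch of the refactored pair (illustrative only; `pipe` stands for a pipeline exposing these methods, and `force_zeros_for_empty_prompt` is now forwarded explicitly rather than read from `self.config`):

```python
import torch

device = torch.device("cuda")

# Conditional-only encoding; no CFG handling inside.
prompt_embeds, pooled_prompt_embeds = pipe.encode_single_prompt(
    "a photo of an astronaut", device=device
)

# Full encoding: with negative_prompt=None and the SDXL
# force_zeros_for_empty_prompt flag set, the negatives come back as zeros.
prompt_embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt(
    "a photo of an astronaut",
    device=device,
    do_classifier_free_guidance=True,
    negative_prompt=None,
    force_zeros_for_empty_prompt=True,
)
assert torch.all(neg_embeds == 0) and torch.all(neg_pooled == 0)
```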

Collaborator: Nice! Do you think we can maybe make this a class method while we are at it? You would need to pass the components in, and lora_scale would not be used if it is None, i.e.

def encode_single_prompt(text_encoder, prompt, device, lora_scale, clip_skip)

(Keep in mind that we potentially want to apply this pattern across all pipelines in the future, so maybe look over the other encode_prompt methods to see if this makes sense.)
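Purely as illustration (not code from this PR), one possible shape for that suggestion; the mixin class and every name beyond the signature above are hypothetical:

```python
import torch

class PromptEncoderMixin:
    @classmethod
    def encode_single_prompt(cls, text_encoder, tokenizer, prompt, device=None,
                             lora_scale=None, clip_skip=None):
        # Components are passed in explicitly, so no pipeline instance
        # (and no unet/vae) is needed just to encode a prompt.
        # lora_scale handling is omitted in this sketch.
        tokens = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        out = text_encoder(tokens.input_ids.to(device), output_hidden_states=True)
        # out[0] is the pooled/projected output for a
        # CLIPTextModelWithProjection (SDXL's second encoder).
        pooled_prompt_embeds = out[0]
        layer = -2 if clip_skip is None else -(clip_skip + 2)
        return out.hidden_states[layer], pooled_prompt_embeds
```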

Contributor Author: Would we pass text_encoder, text_encoder_2, etc., or would we call separately with each and have an additional method to handle concatenation?

Collaborator: Maybe make an encode_single_prompt_clip that can take a list of text_encoders/tokenizers? We're already doing that for sd3/flux: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py#L288

I think a class method would add convenience when used with our regular pipelines (i.e., we now support creating a pipeline without unet & vae quite well, but it's better not to have to create one to begin with). It matters less for modular diffusers, though, so we don't have to make these class methods.

These are all just thoughts; it depends on the use cases we want to support here (long prompts are one example, as is any other customization users want to play with). So let me know what you think.
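Again illustrative only: the helper name comes from the comment above, and the rest is a guess at its shape, loosely following the linked SD3 pattern.

```python
from typing import List, Optional

import torch

def encode_single_prompt_clip(
    text_encoders: List, tokenizers: List, prompts: List[str],
    device: Optional[torch.device] = None, clip_skip: Optional[int] = None,
):
    # One prompt per CLIP encoder (SDXL-style). Hidden states are concatenated
    # along the feature dim; the pooled output comes from the last encoder.
    embeds_list, pooled = [], None
    for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
        input_ids = tokenizer(
            prompt, padding="max_length", max_length=tokenizer.model_max_length,
            truncation=True, return_tensors="pt",
        ).input_ids
        out = text_encoder(input_ids.to(device), output_hidden_states=True)
        pooled = out[0]
        layer = -2 if clip_skip is None else -(clip_skip + 2)
        embeds_list.append(out.hidden_states[layer])
    return torch.cat(embeds_list, dim=-1), pooled
```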


def encode_single_prompt(
self,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -3282,31 +3320,12 @@ def encode_prompt(
used in both text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
@@ -3391,92 +3410,11 @@ def encode_prompt(

prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt

# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_2 = (
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)

uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]

negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin):
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)

negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)

negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

if self.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]

if self.text_encoder_2 is not None:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
else:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if do_classifier_free_guidance:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)

if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
@@ -3487,16 +3425,13 @@ def encode_prompt(
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
return prompt_embeds, pooled_prompt_embeds
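Since encode_single_prompt returns raw per-prompt embeds, it also opens up the prompt customization mentioned in the thread above. A hypothetical blend of two prompts (illustrative; `pipe`, `device`, and feeding the result back through the prompt_embeds passthrough arguments are assumptions):

```python
import torch

embeds_a, pooled_a = pipe.encode_single_prompt("oil painting of a fox", device=device)
embeds_b, pooled_b = pipe.encode_single_prompt("watercolor of a fox", device=device)

# 50/50 blend; reuse via the prompt_embeds / pooled_prompt_embeds arguments.
blended_embeds = 0.5 * embeds_a + 0.5 * embeds_b
blended_pooled = 0.5 * pooled_a + 0.5 * pooled_b
```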

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
self, ip_adapter_image, ip_adapter_image_embeds, device
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
ip_adapter_image_embeds = []
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

@@ -3509,29 +3444,13 @@ def prepare_ip_adapter_image_embeds(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
single_image_embeds = self.encode_image(
single_ip_adapter_image, device, output_hidden_state
)
ip_adapter_image_embeds.append(single_image_embeds[None, :])

image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)

ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

for single_image_embeds in ip_adapter_image_embeds:
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)

return ip_adapter_image_embeds
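The helper's slimmed signature in use (illustrative; a `pipe` with an IP-Adapter loaded is assumed): per-prompt duplication and guidance handling have moved out into the blocks, so the call just encodes.

```python
import torch

embeds_list = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=[ip_image],   # one image per loaded IP-Adapter
    ip_adapter_image_embeds=None,  # or pass precomputed embeds to skip encoding
    device=torch.device("cuda"),
)
# Each entry is a conditional embed; the IP-Adapter block builds the
# negatives (zeros, or a re-encoded image, depending on the projection layer).
```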
