-
Notifications
You must be signed in to change notification settings - Fork 6.1k
[modular] Refactor pipeline functions #10726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -221,15 +221,24 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: | |
ip_adapter_image=data.ip_adapter_image, | ||
ip_adapter_image_embeds=None, | ||
device=data.device, | ||
num_images_per_prompt=1, | ||
do_classifier_free_guidance=data.do_classifier_free_guidance, | ||
) | ||
|
||
if data.do_classifier_free_guidance: | ||
data.negative_ip_adapter_embeds = [] | ||
for i, image_embeds in enumerate(data.ip_adapter_embeds): | ||
negative_image_embeds, image_embeds = image_embeds.chunk(2) | ||
data.negative_ip_adapter_embeds.append(negative_image_embeds) | ||
data.ip_adapter_embeds[i] = image_embeds | ||
output_hidden_states = [not isinstance(image_proj_layer, ImageProjection) for image_proj_layer in pipeline.unet.encoder_hid_proj.image_projection_layers] | ||
negative_ip_adapter_embeds = [] | ||
for (idx, output_hidden_state), ip_adapter_embeds in zip(enumerate(output_hidden_states), data.ip_adapter_embeds): | ||
if not output_hidden_state: | ||
negative_ip_adapter_embed = torch.zeros_like(ip_adapter_embeds) | ||
else: | ||
ip_adapter_image = data.ip_adapter_image[idx] if isinstance(data.ip_adapter_image, list) else data.ip_adapter_image | ||
ip_adapter_image = pipeline.feature_extractor(ip_adapter_image, return_tensors="pt").pixel_values | ||
negative_ip_adapter_embed = pipeline.prepare_ip_adapter_image_embeds( | ||
ip_adapter_image=ip_adapter_image, | ||
ip_adapter_image_embeds=None, | ||
device=data.device, | ||
) | ||
negative_ip_adapter_embeds.append(negative_ip_adapter_embed) | ||
data.negative_ip_adapter_embeds = negative_ip_adapter_embeds | ||
|
||
self.add_block_state(state, data) | ||
return pipeline, state | ||
|
@@ -340,7 +349,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: | |
data.prompt, | ||
data.prompt_2, | ||
data.device, | ||
1, | ||
data.do_classifier_free_guidance, | ||
data.negative_prompt, | ||
data.negative_prompt_2, | ||
|
@@ -350,6 +358,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: | |
negative_pooled_prompt_embeds=None, | ||
lora_scale=data.text_encoder_lora_scale, | ||
clip_skip=data.clip_skip, | ||
force_zeros_for_empty_prompt=self.configs.get('force_zeros_for_empty_prompt', False), | ||
) | ||
# Add outputs | ||
self.add_block_state(state, data) | ||
|
@@ -3197,8 +3206,7 @@ def _get_add_time_ids_img2img( | |
|
||
return add_time_ids, add_neg_time_ids | ||
|
||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image | ||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): | ||
def encode_image(self, image, device, output_hidden_states=None): | ||
dtype = next(self.image_encoder.parameters()).dtype | ||
|
||
if not isinstance(image, torch.Tensor): | ||
|
@@ -3207,20 +3215,10 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state | |
image = image.to(device=device, dtype=dtype) | ||
if output_hidden_states: | ||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] | ||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) | ||
uncond_image_enc_hidden_states = self.image_encoder( | ||
torch.zeros_like(image), output_hidden_states=True | ||
).hidden_states[-2] | ||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( | ||
num_images_per_prompt, dim=0 | ||
) | ||
return image_enc_hidden_states, uncond_image_enc_hidden_states | ||
return image_enc_hidden_states | ||
else: | ||
image_embeds = self.image_encoder(image).image_embeds | ||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) | ||
uncond_image_embeds = torch.zeros_like(image_embeds) | ||
|
||
return image_embeds, uncond_image_embeds | ||
return image_embeds | ||
|
||
# Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image | ||
# 1. return image without apply any guidance | ||
|
@@ -3254,13 +3252,11 @@ def prepare_control_image( | |
|
||
return image | ||
|
||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt | ||
def encode_prompt( | ||
self, | ||
prompt: str, | ||
prompt_2: Optional[str] = None, | ||
device: Optional[torch.device] = None, | ||
num_images_per_prompt: int = 1, | ||
do_classifier_free_guidance: bool = True, | ||
negative_prompt: Optional[str] = None, | ||
negative_prompt_2: Optional[str] = None, | ||
|
@@ -3270,6 +3266,48 @@ def encode_prompt( | |
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, | ||
lora_scale: Optional[float] = None, | ||
clip_skip: Optional[int] = None, | ||
force_zeros_for_empty_prompt: bool = False, | ||
): | ||
( | ||
prompt_embeds, | ||
pooled_prompt_embeds, | ||
) = self.encode_single_prompt( | ||
prompt, | ||
prompt_2, | ||
device, | ||
prompt_embeds=prompt_embeds, | ||
pooled_prompt_embeds=pooled_prompt_embeds, | ||
lora_scale=lora_scale, | ||
clip_skip=clip_skip, | ||
) | ||
zero_out_negative_prompt = negative_prompt is None and force_zeros_for_empty_prompt | ||
if do_classifier_free_guidance and zero_out_negative_prompt: | ||
negative_prompt_embeds = torch.zeros_like(prompt_embeds) | ||
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) | ||
elif do_classifier_free_guidance and not zero_out_negative_prompt: | ||
( | ||
negative_prompt_embeds, | ||
negative_pooled_prompt_embeds, | ||
) = self.encode_single_prompt( | ||
negative_prompt, | ||
negative_prompt_2, | ||
device, | ||
prompt_embeds=negative_prompt_embeds, | ||
pooled_prompt_embeds=negative_pooled_prompt_embeds, | ||
lora_scale=lora_scale, | ||
clip_skip=clip_skip, | ||
) | ||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds | ||
|
||
def encode_single_prompt( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice! do you think we can maybe make this a class method while we are at this? def encode_single_prompt(text_encoder, prompt, device, lora_scale, clip_skip) (have this in mind that we want to potentially apply this pattern across all pipeline in the future, so maybe look over other encode_prompts to see if this makes sense) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would we pass There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. maybe make a I think class method would add convenience when using with our regular pipeline (i.e., we now support creating a pipeline without unet & vae very well but it's better not having to create one to begin with). But I think it does not matter much with modular diffusers, so we don't have to make class methods too. all are just thoughts here, I think it depends on the use case we want to support here (e.g. long prompt is an example, or any other customization users want to play with prompts). so let me know what you think |
||
self, | ||
prompt: str, | ||
prompt_2: Optional[str] = None, | ||
device: Optional[torch.device] = None, | ||
prompt_embeds: Optional[torch.Tensor] = None, | ||
pooled_prompt_embeds: Optional[torch.Tensor] = None, | ||
lora_scale: Optional[float] = None, | ||
clip_skip: Optional[int] = None, | ||
): | ||
r""" | ||
Encodes the prompt into text encoder hidden states. | ||
|
@@ -3282,31 +3320,12 @@ def encode_prompt( | |
used in both text-encoders | ||
device: (`torch.device`): | ||
torch device | ||
num_images_per_prompt (`int`): | ||
number of images that should be generated per prompt | ||
do_classifier_free_guidance (`bool`): | ||
whether to use classifier free guidance or not | ||
negative_prompt (`str` or `List[str]`, *optional*): | ||
The prompt or prompts not to guide the image generation. If not defined, one has to pass | ||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is | ||
less than `1`). | ||
negative_prompt_2 (`str` or `List[str]`, *optional*): | ||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and | ||
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders | ||
prompt_embeds (`torch.Tensor`, *optional*): | ||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not | ||
provided, text embeddings will be generated from `prompt` input argument. | ||
negative_prompt_embeds (`torch.Tensor`, *optional*): | ||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt | ||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input | ||
argument. | ||
pooled_prompt_embeds (`torch.Tensor`, *optional*): | ||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. | ||
If not provided, pooled text embeddings will be generated from `prompt` input argument. | ||
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): | ||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt | ||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` | ||
input argument. | ||
lora_scale (`float`, *optional*): | ||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. | ||
clip_skip (`int`, *optional*): | ||
|
@@ -3391,92 +3410,11 @@ def encode_prompt( | |
|
||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) | ||
|
||
# get unconditional embeddings for classifier free guidance | ||
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt | ||
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: | ||
negative_prompt_embeds = torch.zeros_like(prompt_embeds) | ||
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) | ||
elif do_classifier_free_guidance and negative_prompt_embeds is None: | ||
negative_prompt = negative_prompt or "" | ||
negative_prompt_2 = negative_prompt_2 or negative_prompt | ||
|
||
# normalize str to list | ||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt | ||
negative_prompt_2 = ( | ||
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 | ||
) | ||
|
||
uncond_tokens: List[str] | ||
if prompt is not None and type(prompt) is not type(negative_prompt): | ||
raise TypeError( | ||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" | ||
f" {type(prompt)}." | ||
) | ||
elif batch_size != len(negative_prompt): | ||
raise ValueError( | ||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" | ||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" | ||
" the batch size of `prompt`." | ||
) | ||
else: | ||
uncond_tokens = [negative_prompt, negative_prompt_2] | ||
|
||
negative_prompt_embeds_list = [] | ||
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): | ||
if isinstance(self, TextualInversionLoaderMixin): | ||
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) | ||
|
||
max_length = prompt_embeds.shape[1] | ||
uncond_input = tokenizer( | ||
negative_prompt, | ||
padding="max_length", | ||
max_length=max_length, | ||
truncation=True, | ||
return_tensors="pt", | ||
) | ||
|
||
negative_prompt_embeds = text_encoder( | ||
uncond_input.input_ids.to(device), | ||
output_hidden_states=True, | ||
) | ||
# We are only ALWAYS interested in the pooled output of the final text encoder | ||
negative_pooled_prompt_embeds = negative_prompt_embeds[0] | ||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] | ||
|
||
negative_prompt_embeds_list.append(negative_prompt_embeds) | ||
|
||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) | ||
|
||
if self.text_encoder_2 is not None: | ||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) | ||
else: | ||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) | ||
|
||
bs_embed, seq_len, _ = prompt_embeds.shape | ||
# duplicate text embeddings for each generation per prompt, using mps friendly method | ||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) | ||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) | ||
|
||
if do_classifier_free_guidance: | ||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method | ||
seq_len = negative_prompt_embeds.shape[1] | ||
|
||
if self.text_encoder_2 is not None: | ||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) | ||
else: | ||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) | ||
|
||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) | ||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) | ||
|
||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( | ||
bs_embed * num_images_per_prompt, -1 | ||
) | ||
if do_classifier_free_guidance: | ||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( | ||
bs_embed * num_images_per_prompt, -1 | ||
) | ||
|
||
if self.text_encoder is not None: | ||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: | ||
# Retrieve the original scale by scaling back the LoRA layers | ||
|
@@ -3487,16 +3425,13 @@ def encode_prompt( | |
# Retrieve the original scale by scaling back the LoRA layers | ||
unscale_lora_layers(self.text_encoder_2, lora_scale) | ||
|
||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds | ||
return prompt_embeds, pooled_prompt_embeds | ||
|
||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds | ||
def prepare_ip_adapter_image_embeds( | ||
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance | ||
self, ip_adapter_image, ip_adapter_image_embeds, device | ||
): | ||
image_embeds = [] | ||
if do_classifier_free_guidance: | ||
negative_image_embeds = [] | ||
if ip_adapter_image_embeds is None: | ||
ip_adapter_image_embeds = [] | ||
if not isinstance(ip_adapter_image, list): | ||
ip_adapter_image = [ip_adapter_image] | ||
|
||
|
@@ -3509,29 +3444,13 @@ def prepare_ip_adapter_image_embeds( | |
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers | ||
): | ||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection) | ||
single_image_embeds, single_negative_image_embeds = self.encode_image( | ||
single_ip_adapter_image, device, 1, output_hidden_state | ||
single_image_embeds = self.encode_image( | ||
single_ip_adapter_image, device, output_hidden_state | ||
) | ||
ip_adapter_image_embeds.append(single_image_embeds[None, :]) | ||
|
||
image_embeds.append(single_image_embeds[None, :]) | ||
if do_classifier_free_guidance: | ||
negative_image_embeds.append(single_negative_image_embeds[None, :]) | ||
else: | ||
for single_image_embeds in ip_adapter_image_embeds: | ||
if do_classifier_free_guidance: | ||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) | ||
negative_image_embeds.append(single_negative_image_embeds) | ||
image_embeds.append(single_image_embeds) | ||
|
||
ip_adapter_image_embeds = [] | ||
for i, single_image_embeds in enumerate(image_embeds): | ||
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) | ||
if do_classifier_free_guidance: | ||
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) | ||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) | ||
|
||
for single_image_embeds in ip_adapter_image_embeds: | ||
single_image_embeds = single_image_embeds.to(device=device) | ||
ip_adapter_image_embeds.append(single_image_embeds) | ||
|
||
return ip_adapter_image_embeds | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking that we could have a new
_encode_prompt
method (or maybe better to have a public method for this, so use a different name), and refactor the current `encode_prompt`
to use that new method, what do you think? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that works and I agree a public method is better for this, how about
encode_single_prompt
? Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good! and we can discuss with the team and change it later if needed (it will be under this PR for now but I think we should change that for regular pipeline too)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added
encode_single_prompt
and `encode_prompt`
.