Skip to content

Commit 31de879

Browse files
sayakpaulyiyixuxu
andauthored
[IP2P] Make text encoder truly optional in InstructPi2Pix (#6995)
* make text encoder component truly optional. * more fixes * Apply suggestions from code review Co-authored-by: YiYi Xu <[email protected]> --------- Co-authored-by: YiYi Xu <[email protected]>
1 parent 07349c2 commit 31de879

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -553,13 +553,15 @@ def _encode_prompt(
553553
else:
554554
attention_mask = None
555555

556-
prompt_embeds = self.text_encoder(
557-
text_input_ids.to(device),
558-
attention_mask=attention_mask,
559-
)
556+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
560557
prompt_embeds = prompt_embeds[0]
561558

562-
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
559+
if self.text_encoder is not None:
560+
prompt_embeds_dtype = self.text_encoder.dtype
561+
else:
562+
prompt_embeds_dtype = self.unet.dtype
563+
564+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
563565

564566
bs_embed, seq_len, _ = prompt_embeds.shape
565567
# duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -615,7 +617,7 @@ def _encode_prompt(
615617
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
616618
seq_len = negative_prompt_embeds.shape[1]
617619

618-
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
620+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
619621

620622
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
621623
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

0 commit comments

Comments
 (0)