@@ -553,13 +553,15 @@ def _encode_prompt(
         else:
             attention_mask = None

-        prompt_embeds = self.text_encoder(
-            text_input_ids.to(device),
-            attention_mask=attention_mask,
-        )
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
         prompt_embeds = prompt_embeds[0]

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        else:
+            prompt_embeds_dtype = self.unet.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -615,7 +617,7 @@ def _encode_prompt(
         # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
         seq_len = negative_prompt_embeds.shape[1]

-        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

         negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
         negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
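The diff makes the target dtype robust when `self.text_encoder` is `None` (e.g. the caller supplied precomputed `prompt_embeds`), falling back to the UNet's dtype, and reuses the same `prompt_embeds_dtype` for the negative embeddings so both branches agree. A minimal standalone sketch of the fallback; the `pipe` argument and the `resolve_embeds_dtype` name are hypothetical, mirroring the diff:

```python
import torch

def resolve_embeds_dtype(pipe) -> torch.dtype:
    # Prefer the text encoder's dtype; fall back to the UNet's when no
    # text encoder is loaded (hypothetical helper mirroring the diff).
    if pipe.text_encoder is not None:
        return pipe.text_encoder.dtype
    return pipe.unet.dtype

# Both positive and negative embeddings are then cast to this one dtype,
# as the second hunk does via prompt_embeds_dtype:
#   prompt_embeds = prompt_embeds.to(dtype=resolve_embeds_dtype(pipe), device=device)
#   negative_prompt_embeds = negative_prompt_embeds.to(dtype=resolve_embeds_dtype(pipe), device=device)
```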