Skip to content

Commit eda36c4

Browse files
leisuzz蒋硕sayakpaul
authored
Fix dtype error for StableDiffusionXL (#9217)
Fix dtype error Co-authored-by: 蒋硕 <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 803e817 commit eda36c4

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,7 +1084,7 @@ def unwrap_model(model):
10841084

10851085
# Add noise to the model input according to the noise magnitude at each timestep
10861086
# (this is the forward diffusion process)
1087-
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1087+
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
10881088

10891089
# time ids
10901090
def compute_time_ids(original_size, crops_coords_top_left):
@@ -1101,7 +1101,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
11011101

11021102
# Predict the noise residual
11031103
unet_added_conditions = {"time_ids": add_time_ids}
1104-
prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
1104+
prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
11051105
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
11061106
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
11071107
model_pred = unet(

0 commit comments

Comments
 (0)