Skip to content

Commit

Permalink
Fix dtype error for StableDiffusionXL (#9217)
Browse files Browse the repository at this point in the history
Fix dtype error

Co-authored-by: 蒋硕 <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
3 people authored Aug 20, 2024
1 parent 803e817 commit eda36c4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/text_to_image/train_text_to_image_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,7 @@ def unwrap_model(model):

# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)

# time ids
def compute_time_ids(original_size, crops_coords_top_left):
Expand All @@ -1101,7 +1101,7 @@ def compute_time_ids(original_size, crops_coords_top_left):

# Predict the noise residual
unet_added_conditions = {"time_ids": add_time_ids}
prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
model_pred = unet(
Expand Down

0 comments on commit eda36c4

Please sign in to comment.