Skip to content

Commit

Permalink
Merge pull request #1414 from RossM/vae_export_fix
Browse files Browse the repository at this point in the history
Fix VAE checkpoint export
  • Loading branch information
d8ahazard authored Dec 20, 2023
2 parents 0c434ad + b4bf0e2 commit 113708f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions dreambooth/diff_to_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def convert_vae_state_dict(vae_state_dict):
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
weights_to_convert = ["q", "k", "v", "proj_out"]
weights_to_convert = ["q", "k", "v", "proj_out", "to_q", "to_k", "to_v", "to_out.0"]
keys_to_rename = {}
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
Expand All @@ -211,7 +211,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in keys_to_rename.items():
if k in new_state_dict:
print(f"Renaming {k} to {v}")
new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
new_state_dict[v] = new_state_dict[k]
del new_state_dict[k]
return new_state_dict

Expand Down

0 comments on commit 113708f

Please sign in to comment.