From b4bf0e25067f7059265a0ed23c700547670e6cb2 Mon Sep 17 00:00:00 2001 From: Ross Morgan-Linial Date: Wed, 20 Dec 2023 11:04:36 -0800 Subject: [PATCH] Fix VAE checkpoint export reshape_weight_for_sd() was being applied to biases as well as weights. This is not the most elegant fix, but it is the simplest. --- dreambooth/diff_to_sd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dreambooth/diff_to_sd.py b/dreambooth/diff_to_sd.py index a190b5f5..4ce16c3c 100644 --- a/dreambooth/diff_to_sd.py +++ b/dreambooth/diff_to_sd.py @@ -197,7 +197,7 @@ def convert_vae_state_dict(vae_state_dict): v = v.replace(hf_part, sd_part) mapping[k] = v new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} - weights_to_convert = ["q", "k", "v", "proj_out"] + weights_to_convert = ["q", "k", "v", "proj_out", "to_q", "to_k", "to_v", "to_out.0"] keys_to_rename = {} for k, v in new_state_dict.items(): for weight_name in weights_to_convert: @@ -211,7 +211,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in keys_to_rename.items(): if k in new_state_dict: print(f"Renaming {k} to {v}") - new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k]) + new_state_dict[v] = new_state_dict[k] del new_state_dict[k] return new_state_dict