diff --git a/dreambooth/diff_to_sd.py b/dreambooth/diff_to_sd.py index a190b5f5..4ce16c3c 100644 --- a/dreambooth/diff_to_sd.py +++ b/dreambooth/diff_to_sd.py @@ -197,7 +197,7 @@ def convert_vae_state_dict(vae_state_dict): v = v.replace(hf_part, sd_part) mapping[k] = v new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} - weights_to_convert = ["q", "k", "v", "proj_out"] + weights_to_convert = ["q", "k", "v", "proj_out", "to_q", "to_k", "to_v", "to_out.0"] keys_to_rename = {} for k, v in new_state_dict.items(): for weight_name in weights_to_convert: @@ -211,7 +211,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in keys_to_rename.items(): if k in new_state_dict: print(f"Renaming {k} to {v}") - new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k]) + new_state_dict[v] = new_state_dict[k] del new_state_dict[k] return new_state_dict