
Support for control-lora #10686

Open · wants to merge 22 commits into main

3 changes: 2 additions & 1 deletion src/diffusers/models/controlnets/controlnet.py
@@ -19,6 +19,7 @@
from torch.nn import functional as F

from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import BaseOutput, logging
from ..attention_processor import (
@@ -106,7 +107,7 @@ def forward(self, conditioning):
        return embedding


-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
"""
A ControlNet model.

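With `PeftAdapterMixin` added to the bases, a LoRA adapter can be attached to a `ControlNetModel` the same way as to other diffusers models. A minimal sketch of what this enables, assuming the mixin's standard `add_adapter` entry point; the checkpoint id, rank, and `target_modules` below are illustrative, not taken from this PR:

import torch
from diffusers import ControlNetModel
from peft import LoraConfig

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16
)
# Target the attention projections; module names follow diffusers conventions.
lora_config = LoraConfig(r=16, lora_alpha=16, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
controlnet.add_adapter(lora_config)  # add_adapter is provided by PeftAdapterMixin
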
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -128,6 +128,7 @@
from .remote_utils import remote_decode
from .state_dict_utils import (
    convert_all_state_dict_to_peft,
+    convert_control_lora_state_dict_to_peft,
    convert_state_dict_to_diffusers,
    convert_state_dict_to_kohya,
    convert_state_dict_to_peft,
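The new helper is surfaced alongside the other state-dict converters, so downstream code can import it directly; a full conversion example is sketched at the end of this diff:

from diffusers.utils import convert_control_lora_state_dict_to_peft
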
179 changes: 179 additions & 0 deletions src/diffusers/utils/state_dict_utils.py
@@ -55,6 +55,36 @@ class StateDictType(enum.Enum):
".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
}

CONTROL_LORA_TO_DIFFUSERS = {
    ".to_q.down": ".to_q.lora_A.weight",
    ".to_q.up": ".to_q.lora_B.weight",
    ".to_k.down": ".to_k.lora_A.weight",
    ".to_k.up": ".to_k.lora_B.weight",
    ".to_v.down": ".to_v.lora_A.weight",
    ".to_v.up": ".to_v.lora_B.weight",
    ".to_out.0.down": ".to_out.0.lora_A.weight",
    ".to_out.0.up": ".to_out.0.lora_B.weight",
    ".ff.net.0.proj.down": ".ff.net.0.proj.lora_A.weight",
    ".ff.net.0.proj.up": ".ff.net.0.proj.lora_B.weight",
    ".ff.net.2.down": ".ff.net.2.lora_A.weight",
    ".ff.net.2.up": ".ff.net.2.lora_B.weight",
    ".proj_in.down": ".proj_in.lora_A.weight",
    ".proj_in.up": ".proj_in.lora_B.weight",
    ".proj_out.down": ".proj_out.lora_A.weight",
    ".proj_out.up": ".proj_out.lora_B.weight",
    ".conv.down": ".conv.lora_A.weight",
    ".conv.up": ".conv.lora_B.weight",
    **{f".conv{i}.down": f".conv{i}.lora_A.weight" for i in range(1, 3)},
    **{f".conv{i}.up": f".conv{i}.lora_B.weight" for i in range(1, 3)},
    "conv_in.down": "conv_in.lora_A.weight",
    "conv_in.up": "conv_in.lora_B.weight",
    ".conv_shortcut.down": ".conv_shortcut.lora_A.weight",
    ".conv_shortcut.up": ".conv_shortcut.lora_B.weight",
    **{f".linear_{i}.down": f".linear_{i}.lora_A.weight" for i in range(1, 3)},
    **{f".linear_{i}.up": f".linear_{i}.lora_B.weight" for i in range(1, 3)},
    "time_emb_proj.down": "time_emb_proj.lora_A.weight",
    "time_emb_proj.up": "time_emb_proj.lora_B.weight",
}
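
# Hedged illustration of how the table above is applied (the key below is made
# up, not from a real checkpoint): `convert_state_dict` performs substring
# replacement, so a control-LoRA entry such as
#     "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.down"
# is rewritten to
#     "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight",
# matching PEFT's lora_A (down-projection) / lora_B (up-projection) naming.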

DIFFUSERS_TO_PEFT = {
".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
@@ -258,6 +288,155 @@ def convert_unet_state_dict_to_peft(state_dict):
    return convert_state_dict(state_dict, mapping)


def convert_control_lora_state_dict_to_peft(state_dict):
    def _convert_controlnet_to_diffusers(state_dict):
        # SD15 checkpoints contain a 12th input block; its absence marks an SDXL checkpoint.
        is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict
        logger.info(f"Using ControlNet LoRA ({'SDXL' if is_sdxl else 'SD15'})")

        # Retrieve the keys for the input blocks only
        num_input_blocks = len(
            {".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer}
        )
        input_blocks = {
            layer_id: [key for key in state_dict if f"input_blocks.{layer_id}" in key]
            for layer_id in range(num_input_blocks)
        }
        layers_per_block = 2

        # Downsampler ("op") blocks
        op_blocks = [key for key in state_dict if "0.op" in key]

        converted_state_dict = {}
        # Conv-in layers
        for key in input_blocks[0]:
            diffusers_key = key.replace("input_blocks.0.0", "conv_in")
            converted_state_dict[diffusers_key] = state_dict.get(key)

        # ControlNet time embedding blocks
        time_embedding_blocks = [key for key in state_dict if "time_embed" in key]
        for key in time_embedding_blocks:
            diffusers_key = (
                key.replace("time_embed.0", "time_embedding.linear_1")
                .replace("time_embed.2", "time_embedding.linear_2")
            )
            converted_state_dict[diffusers_key] = state_dict.get(key)

        # ControlNet label embedding blocks (present in SDXL checkpoints)
        label_embedding_blocks = [key for key in state_dict if "label_emb" in key]
        for key in label_embedding_blocks:
            diffusers_key = (
                key.replace("label_emb.0.0", "add_embedding.linear_1")
                .replace("label_emb.0.2", "add_embedding.linear_2")
            )
            converted_state_dict[diffusers_key] = state_dict.get(key)

        # Down blocks
        for i in range(1, num_input_blocks):
            # Each down block holds `layers_per_block` resnets plus one downsampler, so
            # input_blocks 1-3 map to down_block 0, 4-6 to down_block 1, and so on.
            block_id = (i - 1) // (layers_per_block + 1)
            layer_in_block_id = (i - 1) % (layers_per_block + 1)

            resnets = [
                key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
            ]
            for key in resnets:
                diffusers_key = (
                    key.replace("in_layers.0", "norm1")
                    .replace("in_layers.2", "conv1")
                    .replace("out_layers.0", "norm2")
                    .replace("out_layers.3", "conv2")
                    .replace("emb_layers.1", "time_emb_proj")
                    .replace("skip_connection", "conv_shortcut")
                )
                diffusers_key = diffusers_key.replace(
                    f"input_blocks.{i}.0", f"down_blocks.{block_id}.resnets.{layer_in_block_id}"
                )
                converted_state_dict[diffusers_key] = state_dict.get(key)

            if f"input_blocks.{i}.0.op.bias" in state_dict:
                for key in [key for key in op_blocks if f"input_blocks.{i}.0.op" in key]:
                    diffusers_key = key.replace(
                        f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv"
                    )
                    converted_state_dict[diffusers_key] = state_dict.get(key)

            attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
            for key in attentions:
                diffusers_key = key.replace(
                    f"input_blocks.{i}.1", f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
                )
                converted_state_dict[diffusers_key] = state_dict.get(key)

        # ControlNet down blocks (zero convolutions)
        for i in range(num_input_blocks):
            converted_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.get(f"zero_convs.{i}.0.weight")
            converted_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.get(f"zero_convs.{i}.0.bias")

        # Retrieve the keys for the middle blocks only
        num_middle_blocks = len(
            {".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer}
        )
        middle_blocks = {
            layer_id: [key for key in state_dict if f"middle_block.{layer_id}" in key]
            for layer_id in range(num_middle_blocks)
        }

        # Mid blocks: even indices are resnets, odd indices are attentions
        # (middle_block.0 -> resnets.0, middle_block.1 -> attentions.0, middle_block.2 -> resnets.1).
        for key in middle_blocks.keys():
            diffusers_key = max(key - 1, 0)
            if key % 2 == 0:
                for k in middle_blocks[key]:
                    diffusers_key_hf = (
                        k.replace("in_layers.0", "norm1")
                        .replace("in_layers.2", "conv1")
                        .replace("out_layers.0", "norm2")
                        .replace("out_layers.3", "conv2")
                        .replace("emb_layers.1", "time_emb_proj")
                        .replace("skip_connection", "conv_shortcut")
                    )
                    diffusers_key_hf = diffusers_key_hf.replace(
                        f"middle_block.{key}", f"mid_block.resnets.{diffusers_key}"
                    )
                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)
            else:
                for k in middle_blocks[key]:
                    diffusers_key_hf = k.replace(
                        f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}"
                    )
                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)

        # ControlNet mid block (zero convolution)
        converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight")
        converted_state_dict["controlnet_mid_block.bias"] = state_dict.get("middle_block_out.0.bias")

        # ControlNet cond embedding blocks; the first and last hint convs are handled separately below.
        cond_embedding_blocks = {
            ".".join(layer.split(".")[:2])
            for layer in state_dict
            if "input_hint_block" in layer
            and ("input_hint_block.0" not in layer)
            and ("input_hint_block.14" not in layer)
        }
        num_cond_embedding_blocks = len(cond_embedding_blocks)

        for idx in range(1, num_cond_embedding_blocks + 1):
            diffusers_idx = idx - 1
            # Only even hint-block indices carry weights; the odd ones are activations.
            cond_block_id = 2 * idx

            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = state_dict.get(
                f"input_hint_block.{cond_block_id}.weight"
            )
            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = state_dict.get(
                f"input_hint_block.{cond_block_id}.bias"
            )

        for key in [key for key in state_dict if "input_hint_block.0" in key]:
            diffusers_key = key.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in")
            converted_state_dict[diffusers_key] = state_dict.get(key)

        for key in [key for key in state_dict if "input_hint_block.14" in key]:
            diffusers_key = key.replace("input_hint_block.14", "controlnet_cond_embedding.conv_out")
            converted_state_dict[diffusers_key] = state_dict.get(key)

        return converted_state_dict

    state_dict = _convert_controlnet_to_diffusers(state_dict)
    mapping = CONTROL_LORA_TO_DIFFUSERS
    return convert_state_dict(state_dict, mapping)


def convert_all_state_dict_to_peft(state_dict):
r"""
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
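Taken together, a hedged end-to-end sketch of how the new converter is meant to be used; the checkpoint filename is hypothetical, and `load_file` is the standard safetensors loader:

from safetensors.torch import load_file
from diffusers.utils import convert_control_lora_state_dict_to_peft

raw_state_dict = load_file("control-lora-canny-rank128.safetensors")  # hypothetical local file
peft_state_dict = convert_control_lora_state_dict_to_peft(raw_state_dict)
# Keys now follow diffusers/PEFT naming, e.g. "...to_q.lora_A.weight" for the
# down-projection and "...to_q.lora_B.weight" for the up-projection.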