Skip to content

Commit 1450c2a

Browse files
guiyrthlky
andauthored
Multi IP-Adapter for Flux pipelines (#10867)
* Initial implementation of Flux multi IP-Adapter * Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky <[email protected]> * Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky <[email protected]> * Changes for ipa image embeds * Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky <[email protected]> * Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky <[email protected]> * make style && make quality * Updated ip_adapter test * Created typing_utils.py --------- Co-authored-by: hlky <[email protected]>
1 parent cc7b5b8 commit 1450c2a

File tree

9 files changed

+193
-110
lines changed

9 files changed

+193
-110
lines changed

src/diffusers/loaders/ip_adapter.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
2424
from ..utils import (
2525
USE_PEFT_BACKEND,
26+
_get_detailed_type,
2627
_get_model_file,
28+
_is_valid_type,
2729
is_accelerate_available,
2830
is_torch_version,
2931
is_transformers_available,
@@ -577,29 +579,36 @@ def LinearStrengthModel(start, finish, size):
577579
pipeline.set_ip_adapter_scale(ip_strengths)
578580
```
579581
"""
580-
transformer = self.transformer
581-
if not isinstance(scale, list):
582-
scale = [[scale] * transformer.config.num_layers]
583-
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
584-
if len(scale) != transformer.config.num_layers:
585-
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
582+
583+
scale_type = Union[int, float]
584+
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
585+
num_layers = self.transformer.config.num_layers
586+
587+
# Single value for all layers of all IP-Adapters
588+
if isinstance(scale, scale_type):
589+
scale = [scale for _ in range(num_ip_adapters)]
590+
# List of per-layer scales for a single IP-Adapter
591+
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
586592
scale = [scale]
593+
# Invalid scale type
594+
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
595+
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
587596

588-
scale_configs = scale
597+
if len(scale) != num_ip_adapters:
598+
raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
589599

590-
key_id = 0
591-
for attn_name, attn_processor in transformer.attn_processors.items():
592-
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
593-
if len(scale_configs) != len(attn_processor.scale):
594-
raise ValueError(
595-
f"Cannot assign {len(scale_configs)} scale_configs to "
596-
f"{len(attn_processor.scale)} IP-Adapter."
597-
)
598-
elif len(scale_configs) == 1:
599-
scale_configs = scale_configs * len(attn_processor.scale)
600-
for i, scale_config in enumerate(scale_configs):
601-
attn_processor.scale[i] = scale_config[key_id]
602-
key_id += 1
600+
if any(len(s) != num_layers for s in scale if isinstance(s, list)):
601+
invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
602+
raise ValueError(
603+
f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
604+
)
605+
606+
# Scalars are transformed to lists with length num_layers
607+
scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
608+
609+
# Set scales. zip over scale_configs prevents going into single transformer layers
610+
for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
611+
attn_processor.scale = scale
603612

604613
def unload_ip_adapter(self):
605614
"""

src/diffusers/models/attention_processor.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2780,9 +2780,8 @@ def __call__(
27802780

27812781
# IP-adapter
27822782
ip_query = hidden_states_query_proj
2783-
ip_attn_output = None
2784-
# for ip-adapter
2785-
# TODO: support for multiple adapters
2783+
ip_attn_output = torch.zeros_like(hidden_states)
2784+
27862785
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
27872786
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
27882787
):
@@ -2793,12 +2792,14 @@ def __call__(
27932792
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
27942793
# the output of sdp = (batch, num_heads, seq_len, head_dim)
27952794
# TODO: add support for attn.scale when we move to Torch 2.1
2796-
ip_attn_output = F.scaled_dot_product_attention(
2795+
current_ip_hidden_states = F.scaled_dot_product_attention(
27972796
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
27982797
)
2799-
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2800-
ip_attn_output = scale * ip_attn_output
2801-
ip_attn_output = ip_attn_output.to(ip_query.dtype)
2798+
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
2799+
batch_size, -1, attn.heads * head_dim
2800+
)
2801+
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
2802+
ip_attn_output += scale * current_ip_hidden_states
28022803

28032804
return hidden_states, encoder_hidden_states, ip_attn_output
28042805
else:

src/diffusers/models/embeddings.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2583,6 +2583,11 @@ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[
25832583
super().__init__()
25842584
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
25852585

2586+
@property
2587+
def num_ip_adapters(self) -> int:
2588+
"""Number of IP-Adapters loaded."""
2589+
return len(self.image_projection_layers)
2590+
25862591
def forward(self, image_embeds: List[torch.Tensor]):
25872592
projected_image_embeds = []
25882593

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -405,23 +405,28 @@ def prepare_ip_adapter_image_embeds(
405405
if not isinstance(ip_adapter_image, list):
406406
ip_adapter_image = [ip_adapter_image]
407407

408-
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
408+
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
409409
raise ValueError(
410-
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
410+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
411411
)
412412

413-
for single_ip_adapter_image, image_proj_layer in zip(
414-
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
415-
):
413+
for single_ip_adapter_image in ip_adapter_image:
416414
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
417-
418415
image_embeds.append(single_image_embeds[None, :])
419416
else:
417+
if not isinstance(ip_adapter_image_embeds, list):
418+
ip_adapter_image_embeds = [ip_adapter_image_embeds]
419+
420+
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
421+
raise ValueError(
422+
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
423+
)
424+
420425
for single_image_embeds in ip_adapter_image_embeds:
421426
image_embeds.append(single_image_embeds)
422427

423428
ip_adapter_image_embeds = []
424-
for i, single_image_embeds in enumerate(image_embeds):
429+
for single_image_embeds in image_embeds:
425430
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
426431
single_image_embeds = single_image_embeds.to(device=device)
427432
ip_adapter_image_embeds.append(single_image_embeds)
@@ -872,10 +877,13 @@ def __call__(
872877
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
873878
):
874879
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
880+
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
881+
875882
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
876883
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
877884
):
878885
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
886+
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
879887

880888
if self.joint_attention_kwargs is None:
881889
self._joint_attention_kwargs = {}

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 1 addition & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import re
1818
import warnings
1919
from pathlib import Path
20-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin
20+
from typing import Any, Callable, Dict, List, Optional, Union
2121

2222
import requests
2323
import torch
@@ -1059,76 +1059,3 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
10591059
break
10601060
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
10611061
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
1062-
1063-
1064-
def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
1065-
"""
1066-
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
1067-
the correct type as well.
1068-
"""
1069-
if not isinstance(class_or_tuple, tuple):
1070-
class_or_tuple = (class_or_tuple,)
1071-
1072-
# Unpack unions
1073-
unpacked_class_or_tuple = []
1074-
for t in class_or_tuple:
1075-
if get_origin(t) is Union:
1076-
unpacked_class_or_tuple.extend(get_args(t))
1077-
else:
1078-
unpacked_class_or_tuple.append(t)
1079-
class_or_tuple = tuple(unpacked_class_or_tuple)
1080-
1081-
if Any in class_or_tuple:
1082-
return True
1083-
1084-
obj_type = type(obj)
1085-
# Classes with obj's type
1086-
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
1087-
1088-
# Singular types (e.g. int, ControlNet, ...)
1089-
# Untyped collections (e.g. List, but not List[int])
1090-
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
1091-
if () in elem_class_or_tuple:
1092-
return True
1093-
# Typed lists or sets
1094-
elif obj_type in (list, set):
1095-
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
1096-
# Typed tuples
1097-
elif obj_type is tuple:
1098-
return any(
1099-
# Tuples with any length and single type (e.g. Tuple[int, ...])
1100-
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
1101-
or
1102-
# Tuples with fixed length and any types (e.g. Tuple[int, str])
1103-
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
1104-
for t in elem_class_or_tuple
1105-
)
1106-
# Typed dicts
1107-
elif obj_type is dict:
1108-
return any(
1109-
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
1110-
for kt, vt in elem_class_or_tuple
1111-
)
1112-
1113-
else:
1114-
return False
1115-
1116-
1117-
def _get_detailed_type(obj: Any) -> Type:
1118-
"""
1119-
Gets a detailed type for an object, including nested types for collections.
1120-
"""
1121-
obj_type = type(obj)
1122-
1123-
if obj_type in (list, set):
1124-
obj_origin_type = List if obj_type is list else Set
1125-
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
1126-
return obj_origin_type[elems_type]
1127-
elif obj_type is tuple:
1128-
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
1129-
elif obj_type is dict:
1130-
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
1131-
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
1132-
return Dict[keys_type, values_type]
1133-
else:
1134-
return obj_type

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
DEPRECATED_REVISION_ARGS,
5555
BaseOutput,
5656
PushToHubMixin,
57+
_get_detailed_type,
58+
_is_valid_type,
5759
is_accelerate_available,
5860
is_accelerate_version,
5961
is_torch_npu_available,
@@ -78,12 +80,10 @@
7880
_fetch_class_library_tuple,
7981
_get_custom_components_and_folders,
8082
_get_custom_pipeline_class,
81-
_get_detailed_type,
8283
_get_final_device_map,
8384
_get_ignore_patterns,
8485
_get_pipeline_class,
8586
_identify_model_variants,
86-
_is_valid_type,
8787
_maybe_raise_error_for_incorrect_transformers,
8888
_maybe_raise_warning_for_inpainting,
8989
_resolve_custom_pipeline_and_cls,

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
convert_state_dict_to_peft,
124124
convert_unet_state_dict_to_peft,
125125
)
126+
from .typing_utils import _get_detailed_type, _is_valid_type
126127

127128

128129
logger = get_logger(__name__)

src/diffusers/utils/typing_utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Typing utilities: Utilities related to type checking and validation
16+
"""
17+
18+
from typing import Any, Dict, List, Set, Tuple, Type, Union, get_args, get_origin
19+
20+
21+
def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
22+
"""
23+
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
24+
the correct type as well.
25+
"""
26+
if not isinstance(class_or_tuple, tuple):
27+
class_or_tuple = (class_or_tuple,)
28+
29+
# Unpack unions
30+
unpacked_class_or_tuple = []
31+
for t in class_or_tuple:
32+
if get_origin(t) is Union:
33+
unpacked_class_or_tuple.extend(get_args(t))
34+
else:
35+
unpacked_class_or_tuple.append(t)
36+
class_or_tuple = tuple(unpacked_class_or_tuple)
37+
38+
if Any in class_or_tuple:
39+
return True
40+
41+
obj_type = type(obj)
42+
# Classes with obj's type
43+
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
44+
45+
# Singular types (e.g. int, ControlNet, ...)
46+
# Untyped collections (e.g. List, but not List[int])
47+
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
48+
if () in elem_class_or_tuple:
49+
return True
50+
# Typed lists or sets
51+
elif obj_type in (list, set):
52+
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
53+
# Typed tuples
54+
elif obj_type is tuple:
55+
return any(
56+
# Tuples with any length and single type (e.g. Tuple[int, ...])
57+
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
58+
or
59+
# Tuples with fixed length and any types (e.g. Tuple[int, str])
60+
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
61+
for t in elem_class_or_tuple
62+
)
63+
# Typed dicts
64+
elif obj_type is dict:
65+
return any(
66+
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
67+
for kt, vt in elem_class_or_tuple
68+
)
69+
70+
else:
71+
return False
72+
73+
74+
def _get_detailed_type(obj: Any) -> Type:
75+
"""
76+
Gets a detailed type for an object, including nested types for collections.
77+
"""
78+
obj_type = type(obj)
79+
80+
if obj_type in (list, set):
81+
obj_origin_type = List if obj_type is list else Set
82+
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
83+
return obj_origin_type[elems_type]
84+
elif obj_type is tuple:
85+
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
86+
elif obj_type is dict:
87+
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
88+
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
89+
return Dict[keys_type, values_type]
90+
else:
91+
return obj_type

0 commit comments

Comments
 (0)