Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def text_encoder_attn_modules(text_encoder):
"QwenImageLoraLoaderMixin",
"ZImageLoraLoaderMixin",
"Flux2LoraLoaderMixin",
"LongCatLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
Expand Down Expand Up @@ -121,6 +122,7 @@ def text_encoder_attn_modules(text_encoder):
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
KandinskyLoraLoaderMixin,
LongCatLoraLoaderMixin,
LoraLoaderMixin,
LTX2LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Expand Down
207 changes: 207 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5616,6 +5616,213 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
super().unfuse_lora(components=components, **kwargs)


class LongCatLoraLoaderMixin(LoraBaseMixin):
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME

@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)

allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True

user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)

if any("dora_scale" in k for k in state_dict):
logger.warning(
"It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

# Flux-family formats
is_kohya = any(".lora_down.weight" in k for k in state_dict)
if is_kohya:
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
else:
is_xlabs = any("processor" in k for k in state_dict)
if is_xlabs:
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
else:
is_bfl_control = any("query_norm.scale" in k for k in state_dict)
if is_bfl_control:
state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
else:
is_fal_kontext = any("base_model" in k for k in state_dict)
if is_fal_kontext:
state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)

# Generic non-diffusers formats
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
has_default = any("default." in k for k in state_dict)

if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
converted, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
state_dict = converted
if network_alphas is not None:
metadata = {} if metadata is None else dict(metadata)
metadata["network_alphas"] = network_alphas

# Keep only transformer keys
def _is_non_transformer_key(k: str) -> bool:
bad_prefixes = ("unet.", "text_encoder.", "text_encoder_2.", "vae.", "controlnet.")
if k.startswith(bad_prefixes):
return True
bad_substrings = (".unet.", ".text_encoder.", ".text_encoder_2.", ".vae.", ".controlnet.")
return any(s in k for s in bad_substrings)

state_dict = {k: v for k, v in state_dict.items() if not _is_non_transformer_key(k)}

return (state_dict, metadata) if return_lora_metadata else state_dict

def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. "
"Please update it with `pip install -U peft`."
)

if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. "
"Please update it with `pip install -U peft`."
)

logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
lora_layers = {}
lora_metadata = {}

if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

if not lora_layers:
raise ValueError("You must pass at least `transformer_lora_layers`.")

cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)

def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
super().unfuse_lora(components=components, **kwargs)

class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
"LongCatImageTransformer2DModel": lambda model_cls, weights: weights,
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
"LTX2TextConnectors": lambda model_cls, weights: weights,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor

from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin
from ...loaders import FromSingleFileMixin, LongCatLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import LongCatImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
Expand Down Expand Up @@ -202,7 +202,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin):
class LongCatImagePipeline(DiffusionPipeline, LongCatLoraLoaderMixin, FromSingleFileMixin):
r"""
The pipeline for text-to-image generation.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor

from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin
from ...loaders import FromSingleFileMixin, LongCatLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import LongCatImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
Expand Down Expand Up @@ -228,7 +228,7 @@ def calculate_dimensions(target_area, ratio):
return width, height


class LongCatImageEditPipeline(DiffusionPipeline, FromSingleFileMixin):
class LongCatImageEditPipeline(DiffusionPipeline, LongCatLoraLoaderMixin, FromSingleFileMixin):
r"""
The LongCat-Image-Edit pipeline for image editing.
"""
Expand Down
Loading