ACE++ and UNO integration #7931

Open · wants to merge 9 commits into main

6 changes: 6 additions & 0 deletions invokeai/app/invocations/fields.py
@@ -280,6 +280,12 @@ class FluxReduxConditioningField(BaseModel):
    )


class FluxUnoReferenceField(BaseModel):
    """A FLUX UNO reference image list."""

    images: list[ImageField] = Field(description="The images to use as reference for FLUX UNO.")


class FluxFillConditioningField(BaseModel):
    """A FLUX Fill conditioning field."""

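For reference, a minimal sketch of constructing the new field (the image names are illustrative; ImageField is the existing name-based image reference used throughout InvokeAI):

from invokeai.app.invocations.fields import FluxUnoReferenceField, ImageField

refs = FluxUnoReferenceField(
    images=[ImageField(image_name="subject.png"), ImageField(image_name="style.png")]
)
assert [i.image_name for i in refs.images] == ["subject.png", "style.png"]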
44 changes: 44 additions & 0 deletions invokeai/app/invocations/flux_denoise.py
@@ -6,6 +6,7 @@
import numpy.typing as npt
import torch
import torchvision.transforms as tv_transforms
import torchvision.transforms.functional as TVF
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -17,6 +18,7 @@
    FluxConditioningField,
    FluxFillConditioningField,
    FluxReduxConditioningField,
    FluxUnoReferenceField,
    ImageField,
    Input,
    InputField,
@@ -25,6 +27,7 @@
    WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.flux_uno import preprocess_ref
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
@@ -45,6 +48,7 @@
    get_noise,
    get_schedule,
    pack,
    prepare_multi_ip,
    unpack,
)
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
@@ -109,6 +113,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
        description="FLUX Redux conditioning tensor.",
        input=Input.Connection,
    )
    uno_ref: FluxUnoReferenceField | None = InputField(
        default=None,
        description="FLUX UNO reference images.",
        input=Input.Connection,
    )
    fill_conditioning: FluxFillConditioningField | None = InputField(
        default=None,
        description="FLUX Fill conditioning.",
@@ -284,6 +293,14 @@ def _run_diffusion(

        img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)

        if self.uno_ref is not None:
            # Encode reference images and prepare position ids.
            uno_ref_imgs = self._prep_uno_reference_imgs(context=context)
            uno_ref_imgs, uno_ref_ids = prepare_multi_ip(x, uno_ref_imgs)
        else:
            uno_ref_imgs = None
            uno_ref_ids = None

        # Pack all latent tensors.
        init_latents = pack(init_latents) if init_latents is not None else None
        inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
@@ -391,6 +408,8 @@ def _run_diffusion(
                pos_ip_adapter_extensions=pos_ip_adapter_extensions,
                neg_ip_adapter_extensions=neg_ip_adapter_extensions,
                img_cond=img_cond,
                uno_ref_imgs=uno_ref_imgs,
                uno_ref_ids=uno_ref_ids,
            )

        x = unpack(x.float(), self.height, self.width)
@@ -658,6 +677,30 @@ def _prep_controlnet_extensions(

        return controlnet_extensions

    def _prep_uno_reference_imgs(self, context: InvocationContext) -> list[torch.Tensor]:
        # Load the reference images, resize and crop them, and VAE-encode them to latents.
        assert self.uno_ref is not None, "uno_ref must be set when using UNO."
        ref_img_names = [i.image_name for i in self.uno_ref.images]

        assert self.controlnet_vae is not None, "controlnet_vae must be set for UNO encoding."
        vae_info = context.models.load(self.controlnet_vae.vae)

        ref_latents: list[torch.Tensor] = []

        # TODO: Consider exposing the reference long side as a parameter on the UNO node.
        # A long side of 512 is used for a single reference, 320 for multiple (matches the upstream UNO pipeline).
        ref_long_side = 512 if len(ref_img_names) <= 1 else 320

        for img_name in ref_img_names:
            image_pil = context.images.get_pil(img_name, mode="RGB")
            image_pil = preprocess_ref(image_pil, ref_long_side)  # Resize and center crop.

            # Scale to [-1, 1] and add a batch dimension before encoding.
            image_tensor = (TVF.to_tensor(image_pil) * 2.0 - 1.0).unsqueeze(0).float()
            ref_latent = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
            ref_latents.append(ref_latent)

        return ref_latents
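For intuition, a rough shape walk-through of the loop above, assuming the stock FLUX VAE (16 latent channels, 8x spatial downsampling); the sizes are illustrative:

# A single 800x600 reference (so ref_long_side = 512):
#   preprocess_ref            -> 512x384 PIL image (both sides are multiples of 16)
#   TVF.to_tensor + unsqueeze -> (1, 3, 384, 512) tensor in [-1, 1]
#   vae_encode                -> (1, 16, 48, 64) latent appended to ref_latents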

    def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
        if self.control_lora is None:
            return None
@@ -714,6 +757,7 @@ def _prep_flux_fill_img_cond(
        cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
        cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
        cond_img = np.array(cond_img)

        cond_img = torch.from_numpy(cond_img).float() / 127.5 - 1.0
        cond_img = einops.rearrange(cond_img, "h w c -> 1 c h w")
        cond_img = cond_img.to(device=device, dtype=dtype)
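prepare_multi_ip is imported above from invokeai.backend.flux.sampling_utils and comes from the UNO codebase. Roughly, it packs each reference latent into the same 2x2 patch tokens that pack() produces for the main latent, and builds position ids whose row/column indices are shifted so reference tokens sit outside the main image's grid (UNO's UnoPE scheme, with offsets accumulating across references). A simplified sketch of that idea, not the exact upstream implementation:

import torch
from einops import rearrange

def pack_ref_with_offset(ref_latent: torch.Tensor, h_offset: int, w_offset: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Pack one reference latent (b, 16, H, W) into tokens with shifted position ids (simplified)."""
    b, _, h, w = ref_latent.shape
    # Same 2x2 patch packing that pack() applies to the main latent.
    tokens = rearrange(ref_latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    # Columns 1 and 2 hold row/col indices, shifted past the main image's grid.
    ids = torch.zeros(h // 2, w // 2, 3)
    ids[..., 1] = torch.arange(h // 2)[:, None] + h_offset
    ids[..., 2] = torch.arange(w // 2)[None, :] + w_offset
    return tokens, rearrange(ids, "h w c -> (h w) c").unsqueeze(0).expand(b, -1, -1)

The resulting token/id pairs are what _run_diffusion passes to denoise as uno_ref_imgs and uno_ref_ids.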
71 changes: 71 additions & 0 deletions invokeai/app/invocations/flux_uno.py
@@ -0,0 +1,71 @@
from PIL import Image

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FluxUnoReferenceField, InputField, OutputField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext


def preprocess_ref(raw_image: Image.Image, long_size: int = 512) -> Image.Image:
    """Resize and center crop a reference image.

    Code from https://github.com/bytedance/UNO/blob/main/uno/flux/pipeline.py
    """
    # Get the width and height of the original image.
    image_w, image_h = raw_image.size

    # Scale so the long side equals long_size, preserving the aspect ratio.
    if image_w >= image_h:
        new_w = long_size
        new_h = int((long_size / image_w) * image_h)
    else:
        new_h = long_size
        new_w = int((long_size / image_h) * image_w)

    # Resize, then round each side down to a multiple of 16 (FLUX latents are
    # packed into 2x2 patches of 8x-downsampled pixels).
    raw_image = raw_image.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)
    target_w = new_w // 16 * 16
    target_h = new_h // 16 * 16

    # Compute the crop box for a center crop.
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    right = left + target_w
    bottom = top + target_h

    # Center crop.
    raw_image = raw_image.crop((left, top, right, bottom))

    # Convert to RGB mode.
    raw_image = raw_image.convert("RGB")
    return raw_image
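A quick worked example of the arithmetic above:

from PIL import Image

# 1000x700 -> resized to (512, 358); 358 rounds down to 352, so the center crop
# trims 3 px from the top and bottom.
ref = preprocess_ref(Image.new("RGB", (1000, 700)), long_size=512)
assert ref.size == (512, 352)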


@invocation_output("flux_uno_output")
class FluxUnoOutput(BaseInvocationOutput):
    """The output of a FLUX UNO invocation."""

    uno_ref: FluxUnoReferenceField = OutputField(description="Reference images container", title="Reference images")


@invocation(
    "flux_uno",
    title="FLUX UNO",
    tags=["uno", "control"],
    category="ip_adapter",
    version="2.1.0",
    classification=Classification.Beta,
)
class FluxUnoInvocation(BaseInvocation):
    """Loads FLUX UNO reference images."""

    images: list[ImageField] | None = InputField(default=None, description="The UNO reference images.")

    def invoke(self, context: InvocationContext) -> FluxUnoOutput:
        uno_ref = FluxUnoReferenceField(images=self.images or [])
        return FluxUnoOutput(uno_ref=uno_ref)
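End to end: the FLUX UNO node wraps a list of reference images in a FluxUnoReferenceField, and its uno_ref output feeds the new uno_ref connection input on FLUX Denoise, which VAE-encodes the references and attends to them during denoising. A minimal sketch (the image name is illustrative, and context is the InvocationContext normally supplied by the workflow runtime):

# Normally wired through the workflow graph rather than invoked directly:
uno = FluxUnoInvocation(images=[ImageField(image_name="subject.png")])
uno_out = uno.invoke(context)  # -> FluxUnoOutput carrying the FluxUnoReferenceField
# FLUX Denoise then receives uno_out.uno_ref on its uno_ref input; note that
# controlnet_vae must also be set, since _prep_uno_reference_imgs uses it to
# encode the reference images.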