21 changes: 18 additions & 3 deletions flux_train_network.py
@@ -141,6 +141,11 @@ def load_target_model(self, args, weight_dtype, accelerator):

ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

# Enable partitioned decoding for Diffusion4k
if args.partitioned_vae:
ae.decoder.partitioned = True
ae.decoder.stride = 2 # Diffusion4k stride
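# stride=2 adds one more 2x upsample inside the decoder's final stage,
# raising the overall VAE decode factor from 8x to 16x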

return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model

def get_tokenize_strategy(self, args):
@@ -359,8 +364,14 @@ def get_noise_pred_and_target(

# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
latent_height, latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]

if args.partitioned_vae:
img_ids = flux_utils.prepare_partitioned_img_ids(bsz, latent_height, latent_width).to(device=accelerator.device)
else:
img_ids = flux_utils.prepare_img_ids(bsz, latent_height // 2, latent_width // 2).to(device=accelerator.device)

assert packed_noisy_model_input.shape[1] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids"

# get guidance
# ensure guidance_scale in args is float
@@ -408,7 +419,11 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
)

# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
# if args.partitioned_vae:
#     model_pred = flux_utils.unpack_partitioned_latents(model_pred, latent_height, latent_width)  # note: height before width
# else:
#     model_pred = flux_utils.unpack_latents(model_pred, latent_height // 2, latent_width // 2)
model_pred = flux_utils.unpack_latents(model_pred, latents.shape[2] // 2, latents.shape[3] // 2)

# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
67 changes: 63 additions & 4 deletions library/flux_models.py
@@ -54,6 +54,8 @@ class AutoEncoderParams:
z_channels: int
scale_factor: float
shift_factor: float
stride: int
partitioned: bool


def swish(x: Tensor) -> Tensor:
@@ -228,6 +230,8 @@ def __init__(
in_channels: int,
resolution: int,
z_channels: int,
partitioned=False,
stride=1,
):
super().__init__()
self.ch = ch
@@ -236,6 +240,8 @@ def __init__(
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
self.stride = stride
self.partitioned = partitioned

# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
@@ -272,7 +278,7 @@ def __init__(
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

def forward(self, z: Tensor) -> Tensor:
def forward(self, z: Tensor, partitioned=None) -> Tensor:
# z to block_in
h = self.conv_in(z)

@@ -291,9 +297,56 @@ def forward(self, z: Tensor) -> Tensor:
h = self.up[i_level].upsample(h)

# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)

# Diffusion4k
partitioned = partitioned if partitioned is not None else self.partitioned
if self.stride > 1 and partitioned:
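# Diffusion4k partitioned decode: split the final feature map into a
# stride x stride grid of tiles with 1-pixel overlap, upsample and run
# conv_out per tile, then stitch the results back together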
h = self.norm_out(h)
h = swish(h)

overlap_size = 1 # because last conv kernel_size = 3
res = []
partitioned_height = h.shape[2] // self.stride
partitioned_width = h.shape[3] // self.stride

assert self.stride == 2, "partitioned decoding currently supports stride=2 only"
rows = []
for i in range(0, h.shape[2], partitioned_height):
row = []
for j in range(0, h.shape[3], partitioned_width):
partition = h[:, :, max(i - overlap_size, 0) : min(i + partitioned_height + overlap_size, h.shape[2]),
max(j - overlap_size, 0) : min(j + partitioned_width + overlap_size, h.shape[3])]

# zero-pad the outer image borders so every partition reaches conv_out with a full 1-pixel frame; F.pad order is (left, right, top, bottom)
if i == 0 and j == 0:
partition = torch.nn.functional.pad(partition, (1, 0, 1, 0), "constant", 0)  # top-left tile: pad left and top
elif i == 0:
partition = torch.nn.functional.pad(partition, (0, 1, 1, 0), "constant", 0)  # top-right tile: pad right and top
elif i > 0 and j == 0:
partition = torch.nn.functional.pad(partition, (1, 0, 0, 1), "constant", 0)  # bottom-left tile: pad left and bottom
elif i > 0 and j > 0:
partition = torch.nn.functional.pad(partition, (0, 1, 0, 1), "constant", 0)  # bottom-right tile: pad right and bottom

partition = torch.nn.functional.interpolate(partition, scale_factor=self.stride, mode='nearest')
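# apply the final conv to the upsampled tile, then crop the border ring
# (padding/overlap) so the tiles can be stitched back edge to edge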
partition = self.conv_out(partition)
partition = partition[:, :, overlap_size : partitioned_height * 2 + overlap_size, overlap_size : partitioned_width * 2 + overlap_size]

row.append(partition)
rows.append(row)

for row in rows:
res.append(torch.cat(row, dim=3))

h = torch.cat(res, dim=2)
# Diffusion4k
elif self.stride > 1:
h = self.norm_out(h)
h = torch.nn.functional.interpolate(h, scale_factor=self.stride, mode='nearest')
h = swish(h)
h = self.conv_out(h)
else:
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h


@@ -404,6 +457,9 @@ class ModelSpec:
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
# Diffusion4k
stride=1,
partitioned=False,
),
),
"schnell": ModelSpec(
@@ -436,6 +492,9 @@ class ModelSpec:
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
# Diffusion4k
stride=1,
partitioned=False,
),
),
}
64 changes: 50 additions & 14 deletions library/flux_train_utils.py
@@ -232,18 +232,49 @@ def encode_prompt(prpt):

# sample image
weight_dtype = ae.dtype # TODO: pass dtype as an argument
packed_latent_height = height // 16
packed_latent_width = width // 16
noise = torch.randn(
1,
packed_latent_height * packed_latent_width,
16 * 2 * 2,
device=accelerator.device,
dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)

if args.partitioned_vae:
vae_scale_factor = 32
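# the stride-2 partitioned decoder upsamples 16x instead of 8x, so the
# latent grid here is half the usual size: 2 * (pixels // 32) = pixels // 16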
latent_height = 2 * (int(height) // vae_scale_factor)
latent_width = 2 * (int(width) // vae_scale_factor)

print("latent height", latent_height)
print("latent width", latent_width)

noisy_model_input = torch.randn(
1, # Batch size
16, # VAE channels
latent_height,
latent_width,
device=accelerator.device,
dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
)

packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
img_ids = flux_utils.prepare_partitioned_img_ids(1, latent_height, latent_width).to(device=accelerator.device)

print("img_ids: ", img_ids.shape)
else:
# VAE 8x compression
latent_height = height // 8
latent_width = width // 8
noisy_model_input = torch.randn(
1, # Batch size
16, # VAE channels
latent_height,
latent_width,
device=accelerator.device,
dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
)
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
img_ids = flux_utils.prepare_img_ids(1, latent_height // 2, latent_width // 2).to(device=accelerator.device)

assert packed_noisy_model_input.shape[1] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids"

timesteps = get_schedule(sample_steps, packed_noisy_model_input.shape[1], shift=True) # FLUX.1 dev -> shift=True; second arg is the packed image sequence length
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None

if controlnet_image is not None:
Expand All @@ -255,7 +286,7 @@ def encode_prompt(prpt):
with accelerator.autocast(), torch.no_grad():
x = denoise(
flux,
noise,
packed_noisy_model_input,
img_ids,
t5_out,
txt_ids,
@@ -268,7 +299,11 @@ def encode_prompt(prpt):
neg_cond=neg_cond,
)

x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
# unpack latents
if args.partitioned_vae:
x = flux_utils.unpack_latents(x, height // 32, width // 32)
else:
x = flux_utils.unpack_latents(x, latent_height // 2, latent_width // 2)

# latent to image
clean_memory_on_device(accelerator.device)
@@ -680,3 +715,4 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)
parser.add_argument("--partitioned_vae", action="store_true", help="Partitioned VAE from Diffusion4k paper")
61 changes: 53 additions & 8 deletions library/flux_utils.py
@@ -346,23 +346,68 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids

def prepare_partitioned_img_ids(batch_size: int, height: int, width: int):
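# position ids for the packed latent grid: one (0, y, x) triple per 2x2
# patch, returned as (batch_size, (height // 2) * (width // 2), 3)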
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)

return latent_image_ids

def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
def unpack_latents(x: torch.FloatTensor, height: int, width: int) -> torch.FloatTensor:
"""
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=height, w=width, ph=2, pw=2)
return x


def pack_latents(x: torch.Tensor) -> torch.Tensor:
"""
x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

def unpack_partitioned_latents(x, height, width):
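# same 2x2-patch layout as unpack_latents above, but takes the unpacked
# latent height/width and halves them internally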
x = einops.rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=height//2, # Divide by 2 because each patch is 2x2
w=width//2, # Divide by 2 because each patch is 2x2
ph=2,
pw=2
)
return x


def pack_latents(latents):
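# [b, c, h, w] -> [b, (h // 2) * (w // 2), c * 4]: gather each 2x2 patch of
# latent pixels into one token (plain-PyTorch equivalent of the einops rearrange)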
batch_size, channels, height, width = latents.shape
latents = latents.view(batch_size, channels, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), channels * 4)

return latents



# region Diffusers

NUM_DOUBLE_BLOCKS = 19