Commit 8b36d90

feat: support block_to_swap for FLUX.1 ControlNet training
1 parent e369b9a · commit 8b36d90

2 files changed (+45 -14 lines)


README.md

Lines changed: 13 additions & 0 deletions
@@ -14,6 +14,11 @@ The command to install PyTorch is as follows:
 ### Recent Updates
 
+Dec 3, 2024:
+
+- `--blocks_to_swap` now works in FLUX.1 ControlNet training. Sample commands for 24GB VRAM and 16GB VRAM are added [here](#flux1-controlnet-training).
+
 Dec 2, 2024:
 
 - FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details.
@@ -276,6 +281,14 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_tr
 --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed
 ```
 
+For 24GB VRAM GPUs, you can train with 16 blocks swapped, latents and text encoder outputs cached to disk, and a batch size of 1. Remove `--deepspeed`. A sample command is below (not fully tested).
+```
+--blocks_to_swap 16 --cache_latents_to_disk --cache_text_encoder_outputs_to_disk
+```
+
+Training is also possible on 16GB VRAM GPUs with around 30 blocks swapped.
+
+`--gradient_accumulation_steps` is also available. The default is 1 (no accumulation); the original PR uses 8.
 
 ### FLUX.1 OFT training

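For the 16GB VRAM setup mentioned in the README hunk above, the commit does not spell out a full option set. As a rough sketch (an editorial assumption, not part of this commit), the same caching flags would presumably be combined with a higher swap count and the gradient accumulation value the README cites from the original PR:

```
--blocks_to_swap 30 --cache_latents_to_disk --cache_text_encoder_outputs_to_disk --gradient_accumulation_steps 8
```

Swapping more blocks trades GPU memory for extra CPU-GPU transfer, so training steps will be slower than with 16 swapped blocks.
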
flux_train_control_net.py

Lines changed: 32 additions & 14 deletions
@@ -119,9 +119,7 @@ def train(args):
 "datasets": [
     {
         "subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
-            args.train_data_dir,
-            args.conditioning_data_dir,
-            args.caption_extension
+            args.train_data_dir, args.conditioning_data_dir, args.caption_extension
         )
     }
 ]
@@ -263,13 +261,17 @@ def train(args):
     args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
 )
 flux.requires_grad_(False)
-flux.to(accelerator.device)
 
 # load controlnet
-controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
+controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype
+controlnet = flux_utils.load_controlnet(
+    args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
+)
 controlnet.train()
 
 if args.gradient_checkpointing:
+    if not args.deepspeed:
+        flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
     controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
 
 # block swap
@@ -296,7 +298,11 @@ def train(args):
     # This idea is based on 2kpr's great work. Thank you!
     logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
     flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
-    controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)
+    flux.move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
+    # ControlNet only has two blocks, so we can keep it on GPU
+    # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)
+else:
+    flux.to(accelerator.device)
 
 if not cache_latents:
     # load VAE here if not cached
@@ -455,9 +461,7 @@ def train(args):
 else:
     # accelerator does some magic
     # if we don't swap blocks, we can move the model to device
-    controlnet = accelerator.prepare(controlnet, device_placement=[not is_swapping_blocks])
-    if is_swapping_blocks:
-        accelerator.unwrap_model(controlnet).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
+    controlnet = accelerator.prepare(controlnet)  # , device_placement=[not is_swapping_blocks])
 optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
 # Experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
@@ -564,11 +568,13 @@ def grad_hook(parameter: torch.Tensor):
 )
 
 if is_swapping_blocks:
-    accelerator.unwrap_model(controlnet).prepare_block_swap_before_forward()
+    flux.prepare_block_swap_before_forward()
 
 # For --sample_at_first
 optimizer_eval_fn()
-flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet)
+flux_train_utils.sample_images(
+    accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
+)
 optimizer_train_fn()
 if len(accelerator.trackers) > 0:
     # log empty object to commit the sample images to wandb
@@ -629,7 +635,11 @@ def grad_hook(parameter: torch.Tensor):
 # pack latents and get img_ids
 packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
-img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype)
+img_ids = (
+    flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width)
+    .to(device=accelerator.device)
+    .to(weight_dtype)
+)
 
 # get guidance: ensure args.guidance_scale is float
 guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype)
@@ -638,7 +648,7 @@ def grad_hook(parameter: torch.Tensor):
 l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
 if not args.apply_t5_attn_mask:
     t5_attn_mask = None
-
+
 with accelerator.autocast():
     block_samples, block_single_samples = controlnet(
         img=packed_noisy_model_input,
@@ -715,7 +725,15 @@ def grad_hook(parameter: torch.Tensor):
 
 optimizer_eval_fn()
 flux_train_utils.sample_images(
-    accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
+    accelerator,
+    args,
+    None,
+    global_step,
+    flux,
+    ae,
+    [clip_l, t5xxl],
+    sample_prompts_te_outputs,
+    controlnet=controlnet,
 )
 
 # save the model every specified number of steps
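
For reference, here is a condensed editorial sketch of the dtype and device-placement logic the hunks above arrive at. It paraphrases the diff rather than quoting the repository: `flux` and `accelerator` are the objects from the real training script, and the `is_swapping_blocks` check is an assumption about the surrounding code.

```
import torch


def resolve_model_placement(args, flux, accelerator, weight_dtype):
    # Under DeepSpeed the ControlNet is loaded in float32; otherwise the training dtype is
    # used (this value is what gets passed to flux_utils.load_controlnet in the script).
    controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype

    # Assumed check; the actual flag handling lives elsewhere in flux_train_control_net.py.
    is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
    if is_swapping_blocks:
        # Only the FLUX backbone swaps blocks between CPU and GPU; the ControlNet has just
        # two blocks, so it stays resident on the GPU and is not swapped.
        flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
        flux.move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
    else:
        flux.to(accelerator.device)

    return controlnet_dtype, is_swapping_blocks
```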
