Commit 7bfc1ee

WenheLI, sayakpaul, and linoytsaban authored
fix the LR schedulers for dreambooth_lora (#8510)
* update training
* update

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
1 parent 71c0461 commit 7bfc1ee
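
What the change does: `accelerator.prepare` shards the training dataloader across processes, so step counts computed from the unsharded dataloader overestimate how many optimizer steps each epoch actually runs. The diff below therefore computes the scheduler's warmup and total step counts from the expected post-sharding length, scaling both by `accelerator.num_processes` to match how often the prepared scheduler is stepped. A minimal sketch of that step math follows; the numbers are hypothetical stand-ins for values that come from `args`, `train_dataloader`, and `accelerator`:

```python
# Minimal sketch of the scheduler step math introduced by this commit.
# All values below are hypothetical; in the scripts they come from `args`,
# `train_dataloader`, and `accelerator`.
import math

len_train_dataloader = 1000        # len(train_dataloader) before accelerator.prepare
num_processes = 4                  # accelerator.num_processes
gradient_accumulation_steps = 2
num_train_epochs = 3
lr_warmup_steps = 100
max_train_steps = None             # i.e. --max_train_steps was not passed

# Warmup is scaled by the number of processes.
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes

if max_train_steps is None:
    # Estimate the dataloader length after accelerator.prepare shards it.
    len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)
    num_update_steps_per_epoch = math.ceil(
        len_train_dataloader_after_sharding / gradient_accumulation_steps
    )
    num_training_steps_for_scheduler = (
        num_train_epochs * num_update_steps_per_epoch * num_processes
    )
else:
    num_training_steps_for_scheduler = max_train_steps * num_processes

print(num_warmup_steps_for_scheduler)      # 400
print(num_training_steps_for_scheduler)    # 3 * ceil(ceil(1000 / 4) / 2) * 4 = 1500
```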

File tree

2 files changed: +36 −14 lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 18 additions & 7 deletions
@@ -1524,17 +1524,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         torch.cuda.empty_cache()
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1551,8 +1556,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
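
The second hunk re-checks this estimate once `accelerator.prepare` has produced the real sharded dataloader: `max_train_steps` is recomputed and a warning is logged if the total the scheduler was built with no longer matches. A small sketch of that check, again with hypothetical numbers in place of `args` and the prepared dataloader:

```python
# Hypothetical illustration of the post-prepare consistency check added above.
import math
import logging

logger = logging.getLogger(__name__)

num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 3

# Estimate used when the scheduler was created (before accelerator.prepare).
len_train_dataloader_after_sharding = 250
num_training_steps_for_scheduler = (
    num_train_epochs
    * math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)
    * num_processes
)

# Actual length observed after accelerator.prepare (e.g. batching changed it).
len_train_dataloader = 248
num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch

if num_training_steps_for_scheduler != max_train_steps * num_processes:
    logger.warning(
        "The dataloader length after accelerator.prepare (%d) does not match the "
        "estimate (%d) used when the learning rate scheduler was created.",
        len_train_dataloader,
        len_train_dataloader_after_sharding,
    )
```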

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 18 additions & 7 deletions
@@ -1899,17 +1899,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         torch.cuda.empty_cache()
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1926,8 +1931,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
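
The SDXL script receives the identical change. For context, the previous behaviour sized the scheduler from the unsharded dataloader when `--max_train_steps` was omitted, so in multi-GPU runs its total was roughly `num_processes` times larger than the number of optimizer steps actually taken, and warmup and decay never completed on schedule. A before/after comparison with made-up numbers:

```python
# Made-up numbers comparing the old and new scheduler totals in a multi-GPU run.
import math

len_train_dataloader = 1000   # unsharded length, before accelerator.prepare
num_processes = 4
gradient_accumulation_steps = 1
num_train_epochs = 2

# Old behaviour: steps per epoch estimated from the unsharded dataloader.
old_max_train_steps = num_train_epochs * math.ceil(
    len_train_dataloader / gradient_accumulation_steps
)
old_scheduler_total = old_max_train_steps * num_processes          # 8000

# New behaviour: account for sharding before building the scheduler.
sharded_len = math.ceil(len_train_dataloader / num_processes)
new_max_train_steps = num_train_epochs * math.ceil(
    sharded_len / gradient_accumulation_steps
)
new_scheduler_total = new_max_train_steps * num_processes          # 2000

# Each process performs about new_max_train_steps (500) optimizer steps, so the
# old scheduler was configured for ~4x more steps than training ever runs.
print(old_scheduler_total, new_scheduler_total)
```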
