
Commit 20b93e8

committed
More massive messy commits!
Add fancy new graph generation showing LR and loss values. Add LR and loss values to the UI during updates. Fix UI layout for the progress bar and textinfos. Remove hypernetwork junk from xattention, as it's not applicable when training a model. Add the get_scheduler method from an unreleased diffusers version to allow for the new LR params. Add a wrapper class for training output - still need to add it for imagic. Remove the use-CPU option entirely, replace it with "use lora" in the wizard. Add grad accumulation steps to the wizard. Add LR cycles and LR power UI params for applicable LR schedulers. Remove the broken min_learning_rate param. Remove the unused "save class text" param. Update js ui hints. Bump the diffusers version. Make labels more useful, auto-adjusting as needed. Add a manual "check progress" button to the UI, because "gradio".
1 parent a5b520e commit 20b93e8

File tree

11 files changed: +450 −201 lines


dreambooth/db_config.py

Lines changed: 5 additions & 7 deletions
@@ -37,16 +37,17 @@ def __init__(self,
                  half_model: bool = False,
                  has_ema: bool = False,
                  hflip: bool = False,
-                 learning_rate: float = 0.00000172,
+                 learning_rate: float = 5e-6,
                  lora_learning_rate: float = 1e-4,
                  lora_txt_learning_rate: float = 5e-5,
                  lora_txt_weight: float = 1.0,
                  lora_weight: float = 1.0,
+                 lr_cycles: int = 1,
+                 lr_power: float = 1.0,
                  lr_scheduler: str = 'constant',
                  lr_warmup_steps: int = 0,
                  max_token_length: int = 75,
                  max_train_steps: int = 1000,
-                 min_learning_rate: float = .000001,
                  mixed_precision: str = "fp16",
                  model_path: str = "",
                  not_cache_latents=False,
@@ -60,7 +61,6 @@ def __init__(self,
                  save_ckpt_after: bool = True,
                  save_ckpt_cancel: bool = True,
                  save_ckpt_during: bool = True,
-                 save_class_txt: bool = False,
                  save_embedding_every: int = 500,
                  save_lora_after: bool = True,
                  save_lora_cancel: bool = True,
@@ -79,7 +79,6 @@ def __init__(self,
                  train_text_encoder: bool = True,
                  use_8bit_adam: bool = True,
                  use_concepts: bool = False,
-                 use_cpu: bool = False,
                  use_ema: bool = True,
                  use_lora: bool = False,
                  v2: bool = False,
@@ -173,9 +172,10 @@ def __init__(self,
         self.lora_txt_learning_rate = lora_txt_learning_rate
         self.lora_txt_weight = lora_txt_weight
         self.lora_weight = lora_weight
+        self.lr_cycles = lr_cycles
+        self.lr_power = lr_power
         self.lr_scheduler = lr_scheduler
         self.lr_warmup_steps = lr_warmup_steps
-        self.min_learning_rate = min_learning_rate
         self.max_token_length = max_token_length
         self.max_train_steps = max_train_steps
         self.mixed_precision = mixed_precision
@@ -193,7 +193,6 @@ def __init__(self,
         self.save_ckpt_after = save_ckpt_after
         self.save_ckpt_cancel = save_ckpt_cancel
         self.save_ckpt_during = save_ckpt_during
-        self.save_class_txt = save_class_txt
         self.save_embedding_every = save_embedding_every
         self.save_lora_after = save_lora_after
         self.save_lora_cancel = save_lora_cancel
@@ -211,7 +210,6 @@ def __init__(self,
         self.train_text_encoder = train_text_encoder
         self.use_8bit_adam = use_8bit_adam
         self.use_concepts = use_concepts
-        self.use_cpu = use_cpu
         self.use_ema = False if use_ema is None else use_ema
         self.use_lora = False if use_lora is None else use_lora
         if scheduler is not None:
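
The two new constructor fields above feed straight into the learning-rate scheduler set up in train_dreambooth.py. A minimal, self-contained sketch of that mapping; the SchedulerSettings class and scheduler_kwargs helper below are illustrative stand-ins, not code from this commit:

from dataclasses import dataclass

@dataclass
class SchedulerSettings:
    # Field names mirror the DreamboothConfig additions in the diff above.
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 0
    lr_cycles: int = 1      # restart count for cyclic schedules
    lr_power: float = 1.0   # decay exponent for polynomial schedules

def scheduler_kwargs(cfg: SchedulerSettings, max_train_steps: int, grad_accum_steps: int = 1) -> dict:
    # Same keyword names that get_scheduler() receives later in train_dreambooth.py.
    return {
        "num_warmup_steps": cfg.lr_warmup_steps * grad_accum_steps,
        "num_training_steps": max_train_steps * grad_accum_steps,
        "num_cycles": cfg.lr_cycles,
        "power": cfg.lr_power,
    }

print(scheduler_kwargs(SchedulerSettings(lr_scheduler="polynomial", lr_power=2.0), max_train_steps=1000))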

dreambooth/train_dreambooth.py

Lines changed: 83 additions & 44 deletions
@@ -6,15 +6,16 @@
 import time
 import traceback
 from contextlib import nullcontext
+from decimal import Decimal
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Union, List
 
 import torch
 import torch.nn.functional as f
 import torch.utils.checkpoint
+from PIL import Image
 from accelerate import Accelerator
 from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, UNet2DConditionModel, DDPMScheduler
-from diffusers.optimization import get_scheduler, get_polynomial_decay_schedule_with_warmup
 from diffusers.utils import logging as dl
 from huggingface_hub import HfFolder, whoami
 from torch.utils.data import Dataset
@@ -28,7 +29,8 @@
 from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
 from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import encode_hidden_state, \
     EMAModel, generate_classifiers
-from extensions.sd_dreambooth_extension.dreambooth.utils import cleanup, unload_system_models
+from extensions.sd_dreambooth_extension.dreambooth.utils import cleanup, unload_system_models, parse_logs
+from extensions.sd_dreambooth_extension.dreambooth.xattention import get_scheduler
 from extensions.sd_dreambooth_extension.lora_diffusion.lora import save_lora_weight, apply_lora_weights
 from extensions.sd_dreambooth_extension.scripts.dreambooth import printm
 from modules import shared, paths, images
@@ -127,9 +129,35 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
     return f"{organization}/{model_id}"
 
 
+last_samples = []
+
+
+class TrainResult:
+    config: DreamboothConfig = None
+    mem_record: List = []
+    msg: str = ""
+    samples: [Image] = []
+
+
 def main(args: DreamboothConfig, memory_record, use_subdir, lora_model=None, lora_alpha=1.0, lora_txt_alpha=1.0,
-         custom_model_name="", use_txt2img=True) -> tuple[DreamboothConfig, dict, str]:
+         custom_model_name="", use_txt2img=True) -> TrainResult:
+    """
+
+    @param args: The model config to use.
+    @param memory_record: A global memory record. This can probably go away now.
+    @param use_subdir: Save checkpoints to a subdirectory.
+    @param lora_model: An optional lora model to use/resume.
+    @param lora_alpha: The weight to use when applying lora unet.
+    @param lora_txt_alpha: The weight to use when applying lora text encoder.
+    @param custom_model_name: A custom name to use when saving checkpoints.
+    @param use_txt2img: Use txt2img when generating class images.
+    @return: TrainResult
+    """
+    global last_samples
     logging_dir = Path(args.model_dir, "logging")
+    result = TrainResult
+    result.config = args
+
     if profile_memory:
         cleanup(True)
         prof = profile(
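
One thing worth flagging in the hunk above: `result = TrainResult` binds the class object itself rather than an instance, so every attribute assignment mutates class-level state that persists across calls to main(). A minimal sketch of the difference; the two-field TrainResult here is a simplified stand-in, not the real class:

class TrainResult:              # simplified stand-in for illustration
    msg: str = ""

def run_without_parens():
    result = TrainResult        # binds the class itself
    result.msg = "leaked"
    return result

def run_with_parens():
    return TrainResult()        # fresh instance per call

run_without_parens()
print(TrainResult.msg)          # "leaked" -- the class attribute was overwritten
print(run_with_parens().msg)    # also "leaked", via class-attribute fallback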
@@ -171,16 +199,18 @@ def main(args: DreamboothConfig, memory_record, use_subdir, lora_model=None, lor
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             mixed_precision=args.mixed_precision,
             log_with="tensorboard",
-            logging_dir=logging_dir,
-            cpu=args.use_cpu
+            logging_dir=logging_dir
         )
     except Exception as e:
         if "AcceleratorState" in str(e):
             msg = "Change in precision detected, please restart the webUI entirely to use new precision."
         else:
             msg = f"Exception initializing accelerator: {e}"
         print(msg)
-        return args, mem_record, msg
+        result.msg = msg
+        result.mem_record = mem_record
+        result.config = args
+        return result
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
@@ -235,14 +265,14 @@ def main(args: DreamboothConfig, memory_record, use_subdir, lora_model=None, lor
     def create_vae():
         vae_path = args.pretrained_vae_name_or_path if args.pretrained_vae_name_or_path else \
             args.pretrained_model_name_or_path
-        result = AutoencoderKL.from_pretrained(
+        new_vae = AutoencoderKL.from_pretrained(
             vae_path,
             subfolder=None if args.pretrained_vae_name_or_path else "vae",
             revision=args.revision
         )
-        result.requires_grad_(False)
-        result.to(accelerator.device, dtype=weight_dtype)
-        return result
+        new_vae.requires_grad_(False)
+        new_vae.to(accelerator.device, dtype=weight_dtype)
+        return new_vae
 
     vae = create_vae()
 
@@ -354,7 +384,10 @@ def cleanup_memory():
         print(msg)
         status.textinfo = msg
         cleanup_memory()
-        return args, mem_record, msg
+        result.msg = msg
+        result.mem_record = mem_record
+        result.config = args
+        return result
 
     def collate_fn(examples):
         input_ids = [ex["instance_prompt_ids"] for ex in examples]
@@ -386,7 +419,7 @@ def collate_fn(examples):
 
     train_dataloader = torch.utils.data.DataLoader(
         # train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
-        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True
+        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=1
    )
    # Move text_encoder and VAE to GPU.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
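
The dataloader hunk re-enables num_workers=1 alongside pin_memory=True. As a generic PyTorch illustration (not project code): pinned host memory lets .to(device, non_blocking=True) copies overlap with compute, and a worker process keeps batch collation off the main process.

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":                  # guard needed when worker processes spawn (e.g. on Windows)
    dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
    loader = DataLoader(dataset, batch_size=8, shuffle=True, pin_memory=True, num_workers=1)
    for features, labels in loader:
        if torch.cuda.is_available():
            features = features.to("cuda", non_blocking=True)   # async copy from pinned memory
        break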
@@ -452,22 +485,21 @@ def cache_latents(td=None, tdl=None, enc_vae=None, orig_dataset=None):
     if not args.not_cache_latents:
         train_dataset, train_dataloader = cache_latents(enc_vae=vae, orig_dataset=gen_dataset)
 
-    if args.lr_scheduler == "polynomial":
-        lr_scheduler = get_polynomial_decay_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-            num_training_steps=max_train_steps * args.gradient_accumulation_steps,
-            lr_end=args.min_learning_rate,
-            last_epoch=args.epoch
-        )
-        pass
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-            num_training_steps=max_train_steps * args.gradient_accumulation_steps,
-        )
+    # This needs to be done before we set up the optimizer, for reasons that should have been obvious.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if max_train_steps is None or max_train_steps < 1:
+        max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+        num_training_steps=max_train_steps * args.gradient_accumulation_steps,
+        num_cycles=args.lr_cycles,
+        power=args.lr_power,
+    )
 
     # create ema, fix OOM
     if args.use_ema:
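
The vendored get_scheduler (imported from dreambooth/xattention.py above) is not shown in this diff; the call site only tells us it accepts num_cycles and power in addition to the usual warmup/training-step counts. A rough, self-contained sketch of how such a dispatcher can behave for the schedulers those parameters apply to; this is an approximation, not the vendored implementation:

import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def get_scheduler_sketch(name, optimizer, num_warmup_steps, num_training_steps, num_cycles=1, power=1.0):
    def pct(step):
        # fraction of post-warmup progress, clamped to [0, 1]
        return min(1.0, max(0.0, (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)))

    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        if name == "polynomial":
            return (1.0 - pct(step)) ** power            # power only matters here
        if name == "cosine_with_restarts":
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * pct(step)) % 1.0))))
        return 1.0                                        # "constant" and friends ignore both extras

    return LambdaLR(optimizer, lr_lambda)

if __name__ == "__main__":
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=5e-6)
    sched = get_scheduler_sketch("cosine_with_restarts", opt, num_warmup_steps=10,
                                 num_training_steps=100, num_cycles=2)
    for _ in range(5):
        opt.step()
        sched.step()
    print(sched.get_last_lr())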
@@ -494,11 +526,6 @@ def cache_latents(td=None, tdl=None, enc_vae=None, orig_dataset=None):
 
     printm("Scheduler, EMA Loaded.")
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if max_train_steps is None or max_train_steps < 1:
-        max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -548,9 +575,9 @@ def cache_latents(td=None, tdl=None, enc_vae=None, orig_dataset=None):
     print(f" Resuming from checkpoint: {resume_from_checkpoint}")
     print(f" First resume epoch: {first_epoch}")
     print(f" First resume step: {resume_step}")
-    print(f" CPU: {args.use_cpu} Adam: {use_adam}, Prec: {args.mixed_precision}")
-    print(f" Grad: {args.gradient_checkpointing}, TextTr: {args.train_text_encoder} Use EMA: {args.use_ema}")
-    print(f" Learning rate: {args.learning_rate} Use lora:{args.use_lora}")
+    print(f" Lora: {args.use_lora}, Adam: {use_adam}, Prec: {args.mixed_precision}")
+    print(f" Grad: {args.gradient_checkpointing}, Text: {args.train_text_encoder}, EMA: {args.use_ema}")
+    print(f" LR: {args.learning_rate})")
 
     last_img_step = -1
     last_save_step = -1
@@ -627,6 +654,7 @@ def check_save():
         return save_model
 
     def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lora):
+        global last_samples
         # Create the pipeline using the trained modules and save it.
         if accelerator.is_main_process:
             g_cuda = None
@@ -719,6 +747,7 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
                     prompts = gen_dataset.get_sample_prompts()
                     ci = 0
                     samples = []
+                    last_samples = []
                     for c in prompts:
                         seed = c.seed
                         if seed is None or seed == '' or seed == -1:
@@ -739,10 +768,15 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
                             txt_file.write(c.prompt)
                         s_image.save(image_name)
                         ci += 1
+                    for sample in samples:
+                        last_samples.append(sample)
                     if len(samples) > 1:
                         img_grid = images.image_grid(samples)
                         status.current_image = img_grid
                     del samples
+                    log_images = parse_logs(model_name=args.model_name)
+                    for log_image in log_images:
+                        last_samples.insert(0, log_image)
 
                 except Exception as em:
                     print(f"Exception with the stupid image again: {em}")
@@ -855,13 +889,14 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
                 if args.use_ema and ema_unet is not None:
                     ema_unet.step(unet.parameters())
 
-                if not global_step % 2:
+                if not global_step % args.train_batch_size:
                     allocated = round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)
                     cached = round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)
-                    logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0],
-                            "vram": f"{allocated}/{cached}GB"}
-                    status.textinfo2 = f"loss: {loss_avg.avg.item()}, lr: {lr_scheduler.get_last_lr()[0]}" \
-                                       f"vram: {allocated}/{cached}GB"
+                    log_loss = loss_avg.avg.item()
+                    last_lr = lr_scheduler.get_last_lr()[0]
+                    logs = {"loss": log_loss, "lr": last_lr, "vram_usage": f"{allocated}"}
+                    status.textinfo2 = f"Loss: {'%.2f' % log_loss}, LR: {'{:.2E}'.format(Decimal(last_lr))}, " \
+                                       f"VRAM: {allocated}/{cached} GB"
                     progress_bar.set_postfix(**logs)
                     accelerator.log(logs, step=args.revision)
                     loss_avg.reset()
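
The Decimal-based formatting above just keeps the status line compact; for reference, this is what the two expressions from the hunk produce on sample values:

from decimal import Decimal

log_loss = 0.08712
last_lr = 4.37e-06
print(f"Loss: {'%.2f' % log_loss}, LR: {'{:.2E}'.format(Decimal(last_lr))}")
# -> Loss: 0.09, LR: 4.37E-6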
@@ -884,7 +919,7 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
                 training_complete = True
                 tot_step = global_step + lifetime_step
                 status.textinfo = f"Steps: {global_step}/{max_train_steps} (Current)," \
-                                  f" {args.revision}/{tot_step} (Lifetime), Epoch: {args.epoch}"
+                                  f" {args.revision}/{tot_step + args.lifetime_revision} (Lifetime), Epoch: {args.epoch}"
 
             # Log completion message
             if training_complete:
@@ -945,4 +980,8 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
         prof.stop()
     cleanup_memory()
     accelerator.end_training()
-    return args, mem_record, msg
+    result.msg = msg
+    result.config = args
+    result.mem_record = mem_record
+    result.samples = last_samples
+    return result
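
Callers that previously unpacked the old three-tuple return from main() now need to read fields off the TrainResult wrapper instead. A self-contained sketch of that migration, using a stand-in class and a fake entry point; the names here are hypothetical, not from this commit:

from dataclasses import dataclass, field
from typing import List

@dataclass
class TrainResultSketch:                    # stand-in for the real TrainResult
    config: dict = None
    mem_record: List = field(default_factory=list)
    msg: str = ""
    samples: List = field(default_factory=list)

def fake_main() -> TrainResultSketch:       # stand-in for train_dreambooth.main()
    return TrainResultSketch(config={"model_name": "demo"}, msg="Training complete.")

result = fake_main()
# Old call sites unpacked a tuple: config, mem_record, msg = main(...)
config, mem_record, msg = result.config, result.mem_record, result.msg
print(msg, len(result.samples))             # samples now also carry the LR/loss graphs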
