import time
import traceback
from contextlib import nullcontext
+ from decimal import Decimal
from pathlib import Path
- from typing import Optional, Union
+ from typing import Optional, Union, List

import torch
import torch.nn.functional as f
import torch.utils.checkpoint
+ from PIL import Image
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, UNet2DConditionModel, DDPMScheduler
- from diffusers.optimization import get_scheduler, get_polynomial_decay_schedule_with_warmup
from diffusers.utils import logging as dl
from huggingface_hub import HfFolder, whoami
from torch.utils.data import Dataset
@@ -28,7 +29,8 @@
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import encode_hidden_state, \
    EMAModel, generate_classifiers
- from extensions.sd_dreambooth_extension.dreambooth.utils import cleanup, unload_system_models
+ from extensions.sd_dreambooth_extension.dreambooth.utils import cleanup, unload_system_models, parse_logs
+ from extensions.sd_dreambooth_extension.dreambooth.xattention import get_scheduler
from extensions.sd_dreambooth_extension.lora_diffusion.lora import save_lora_weight, apply_lora_weights
from extensions.sd_dreambooth_extension.scripts.dreambooth import printm
from modules import shared, paths, images
@@ -127,9 +129,35 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
        return f"{organization}/{model_id}"


+ last_samples = []
+
+
+ class TrainResult:
+     config: DreamboothConfig = None
+     mem_record: List = []
+     msg: str = ""
+     samples: List[Image] = []
+
+
def main(args: DreamboothConfig, memory_record, use_subdir, lora_model=None, lora_alpha=1.0, lora_txt_alpha=1.0,
-          custom_model_name="", use_txt2img=True) -> tuple[DreamboothConfig, dict, str]:
+          custom_model_name="", use_txt2img=True) -> TrainResult:
+     """
+
+     @param args: The model config to use.
+     @param memory_record: A global memory record. This can probably go away now.
+     @param use_subdir: Save checkpoints to a subdirectory.
+     @param lora_model: An optional lora model to use/resume.
+     @param lora_alpha: The weight to use when applying the lora unet.
+     @param lora_txt_alpha: The weight to use when applying the lora text encoder.
+     @param custom_model_name: A custom name to use when saving checkpoints.
+     @param use_txt2img: Use txt2img when generating class images.
+     @return: TrainResult
+     """
+     global last_samples
    logging_dir = Path(args.model_dir, "logging")
+     result = TrainResult()
+     result.config = args
+
    if profile_memory:
        cleanup(True)
        prof = profile(
@@ -171,16 +199,18 @@ def main(args: DreamboothConfig, memory_record, use_subdir, lora_model=None, lor
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mixed_precision=args.mixed_precision,
            log_with="tensorboard",
-             logging_dir=logging_dir,
-             cpu=args.use_cpu
+             logging_dir=logging_dir
        )
    except Exception as e:
        if "AcceleratorState" in str(e):
            msg = "Change in precision detected, please restart the webUI entirely to use new precision."
        else:
            msg = f"Exception initializing accelerator: {e}"
        print(msg)
-         return args, mem_record, msg
+         result.msg = msg
+         result.mem_record = mem_record
+         result.config = args
+         return result
    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
@@ -235,14 +265,14 @@ def main(args: DreamboothConfig, memory_record, use_subdir, lora_model=None, lor
    def create_vae():
        vae_path = args.pretrained_vae_name_or_path if args.pretrained_vae_name_or_path else \
            args.pretrained_model_name_or_path
-         result = AutoencoderKL.from_pretrained(
+         new_vae = AutoencoderKL.from_pretrained(
            vae_path,
            subfolder=None if args.pretrained_vae_name_or_path else "vae",
            revision=args.revision
        )
-         result.requires_grad_(False)
-         result.to(accelerator.device, dtype=weight_dtype)
-         return result
+         new_vae.requires_grad_(False)
+         new_vae.to(accelerator.device, dtype=weight_dtype)
+         return new_vae

    vae = create_vae()

@@ -354,7 +384,10 @@ def cleanup_memory():
        print(msg)
        status.textinfo = msg
        cleanup_memory()
-         return args, mem_record, msg
+         result.msg = msg
+         result.mem_record = mem_record
+         result.config = args
+         return result

    def collate_fn(examples):
        input_ids = [ex["instance_prompt_ids"] for ex in examples]
@@ -386,7 +419,7 @@ def collate_fn(examples):

    train_dataloader = torch.utils.data.DataLoader(
        # train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
-         train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True
+         train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=1
    )
    # Move text_encoder and VAE to GPU.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
@@ -452,22 +485,21 @@ def cache_latents(td=None, tdl=None, enc_vae=None, orig_dataset=None):
    if not args.not_cache_latents:
        train_dataset, train_dataloader = cache_latents(enc_vae=vae, orig_dataset=gen_dataset)

-     if args.lr_scheduler == "polynomial":
-         lr_scheduler = get_polynomial_decay_schedule_with_warmup(
-             optimizer=optimizer,
-             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-             num_training_steps=max_train_steps * args.gradient_accumulation_steps,
-             lr_end=args.min_learning_rate,
-             last_epoch=args.epoch
-         )
-         pass
-     else:
-         lr_scheduler = get_scheduler(
-             args.lr_scheduler,
-             optimizer=optimizer,
-             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-             num_training_steps=max_train_steps * args.gradient_accumulation_steps,
-         )
+     # This needs to be done before we set up the optimizer, for reasons that should have been obvious.
+     overrode_max_train_steps = False
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if max_train_steps is None or max_train_steps < 1:
+         max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+         overrode_max_train_steps = True
+
+     lr_scheduler = get_scheduler(
+         args.lr_scheduler,
+         optimizer=optimizer,
+         num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+         num_training_steps=max_train_steps * args.gradient_accumulation_steps,
+         num_cycles=args.lr_cycles,
+         power=args.lr_power,
+     )

    # create ema, fix OOM
    if args.use_ema:
@@ -494,11 +526,6 @@ def cache_latents(td=None, tdl=None, enc_vae=None, orig_dataset=None):

    printm("Scheduler, EMA Loaded.")
    # Scheduler and math around the number of training steps.
-     overrode_max_train_steps = False
-     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-     if max_train_steps is None or max_train_steps < 1:
-         max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-         overrode_max_train_steps = True

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -548,9 +575,9 @@ def cache_latents(td=None, tdl=None, enc_vae=None, orig_dataset=None):
    print(f" Resuming from checkpoint: {resume_from_checkpoint}")
    print(f" First resume epoch: {first_epoch}")
    print(f" First resume step: {resume_step}")
-     print(f" CPU: {args.use_cpu} Adam: {use_adam}, Prec: {args.mixed_precision}")
-     print(f" Grad: {args.gradient_checkpointing}, TextTr: {args.train_text_encoder} Use EMA: {args.use_ema}")
-     print(f" Learning rate: {args.learning_rate} Use lora: {args.use_lora}")
+     print(f" Lora: {args.use_lora}, Adam: {use_adam}, Prec: {args.mixed_precision}")
+     print(f" Grad: {args.gradient_checkpointing}, Text: {args.train_text_encoder}, EMA: {args.use_ema}")
+     print(f" LR: {args.learning_rate}")

    last_img_step = -1
    last_save_step = -1
@@ -627,6 +654,7 @@ def check_save():
        return save_model

    def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lora):
+         global last_samples
        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            g_cuda = None
@@ -719,6 +747,7 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
                    prompts = gen_dataset.get_sample_prompts()
                    ci = 0
                    samples = []
+                     last_samples = []
                    for c in prompts:
                        seed = c.seed
                        if seed is None or seed == '' or seed == -1:
@@ -739,10 +768,15 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
                            txt_file.write(c.prompt)
                        s_image.save(image_name)
                        ci += 1
+                     for sample in samples:
+                         last_samples.append(sample)
                    if len(samples) > 1:
                        img_grid = images.image_grid(samples)
                        status.current_image = img_grid
                        del samples
+                     log_images = parse_logs(model_name=args.model_name)
+                     for log_image in log_images:
+                         last_samples.insert(0, log_image)

                except Exception as em:
                    print(f"Exception with the stupid image again: {em}")
@@ -855,13 +889,14 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
                if args.use_ema and ema_unet is not None:
                    ema_unet.step(unet.parameters())

-                 if not global_step % 2:
+                 if not global_step % args.train_batch_size:
                    allocated = round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)
                    cached = round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)
-                     logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0],
-                             "vram": f"{allocated}/{cached}GB"}
-                     status.textinfo2 = f"loss: {loss_avg.avg.item()}, lr: {lr_scheduler.get_last_lr()[0]}" \
-                                        f"vram: {allocated}/{cached}GB"
+                     log_loss = loss_avg.avg.item()
+                     last_lr = lr_scheduler.get_last_lr()[0]
+                     logs = {"loss": log_loss, "lr": last_lr, "vram_usage": f"{allocated}"}
+                     status.textinfo2 = f"Loss: {'%.2f' % log_loss}, LR: {'{:.2E}'.format(Decimal(last_lr))}, " \
+                                        f"VRAM: {allocated}/{cached}GB"
                    progress_bar.set_postfix(**logs)
                    accelerator.log(logs, step=args.revision)
                    loss_avg.reset()
@@ -884,7 +919,7 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
                    training_complete = True
                tot_step = global_step + lifetime_step
                status.textinfo = f"Steps: {global_step}/{max_train_steps} (Current)," \
-                                   f" {args.revision}/{tot_step} (Lifetime), Epoch: {args.epoch}"
+                                   f" {args.revision}/{tot_step + args.lifetime_revision} (Lifetime), Epoch: {args.epoch}"

                # Log completion message
                if training_complete:
@@ -945,4 +980,8 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
        prof.stop()
    cleanup_memory()
    accelerator.end_training()
-     return args, mem_record, msg
+     result.msg = msg
+     result.config = args
+     result.mem_record = mem_record
+     result.samples = last_samples
+     return result