Add Option to use Target Model in LCM-LoRA Scripts #6537

Open · wants to merge 4 commits into base: main
73 changes: 65 additions & 8 deletions examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -236,7 +236,7 @@ def train_dataloader(self):
return self._train_dataloader


def log_validation(vae, unet, args, accelerator, weight_dtype, step):
def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="lora"):
logger.info("Running validation... ")

unet = accelerator.unwrap_model(unet)
@@ -306,7 +306,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)

tracker.log({"validation": formatted_images})
tracker.log({f"validation/{name}": formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")

@@ -677,6 +677,24 @@ def parse_args():
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
)
# ----Latent Consistency Distillation (LCD) Specific Arguments----
parser.add_argument(
"--use_target_model",
action="store_true",
help=(
"Whether to use a target model in addition to the U-Net student model. Using a target model can help"
" improve training stability at the cost of more GPU memory usage."
),
)
parser.add_argument(
"--ema_decay",
type=float,
default=0.95,
required=False,
help=(
"The exponential moving average (EMA) rate or decay factor to be used for updating the target model"
" parameters (if using a target model).",
),
)
parser.add_argument(
"--w_min",
type=float,
@@ -961,6 +979,12 @@ def main(args):
)
unet.train()

if args.use_target_model:
target_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
target_unet.train()

# Check that all trainable models are in full precision
low_precision_error_string = (
" Please make sure to always have all model weights in full float32 precision when starting training - even if"
@@ -1000,6 +1024,10 @@ def main(args):
)
unet = get_peft_model(unet, lora_config)

if args.use_target_model:
target_unet = get_peft_model(target_unet, lora_config)
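# The target adapter is frozen for the optimizer; its parameters change only through the EMA update in the training loop.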
target_unet.requires_grad_(False)

# 9. Handle mixed precision and device placement
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
@@ -1020,6 +1048,8 @@ def main(args):
teacher_unet.to(accelerator.device)
if args.cast_teacher_unet:
teacher_unet.to(dtype=weight_dtype)
if args.use_target_model:
target_unet.to(accelerator.device)

# Also move the alpha and sigma noise schedules to accelerator.device.
alpha_schedule = alpha_schedule.to(accelerator.device)
@@ -1039,14 +1069,25 @@ def save_model_hook(models, weights, output_dir):
# save weights in peft format to be able to load them back
unet_.save_pretrained(os.path.join(output_dir, "unet_lora"))

if args.use_target_model:
target_lora_state_dict = get_peft_model_state_dict(target_unet, adapter_name="default")
StableDiffusionPipeline.save_lora_weights(
os.path.join(output_dir, "unet_target_lora"), target_lora_state_dict
)
# save weights in peft format to be able to load them back
target_unet.save_pretrained(os.path.join(output_dir, "unet_target_lora"))

for _, model in enumerate(models):
# make sure to pop weight so that corresponding model is not saved again
weights.pop()

def load_model_hook(models, input_dir):
# load the LoRA into the model
unet_ = accelerator.unwrap_model(unet)
unet_.load_adapter(input_dir, "default", is_trainable=True)
unet_.load_adapter(os.path.join(input_dir, "unet_lora"), "default", is_trainable=True)

if args.use_target_model:
target_unet.load_adapter(os.path.join(input_dir, "unet_target_lora"), "default", is_trainable=True)

for _ in range(len(models)):
# pop models so that they are not loaded again
@@ -1067,7 +1108,8 @@ def load_model_hook(models, input_dir):
)
unet.enable_xformers_memory_efficient_attention()
teacher_unet.enable_xformers_memory_efficient_attention()
# target_unet.enable_xformers_memory_efficient_attention()
if args.use_target_model:
target_unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

@@ -1351,10 +1393,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
x_prev = solver.ddim_step(pred_x0, pred_noise, index)

# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
if args.use_target_model:
target_model = target_unet
else:
target_model = unet
with torch.no_grad():
with torch.autocast("cuda", dtype=weight_dtype):
target_noise_pred = unet(
target_noise_pred = target_model(
x_prev.float(),
timesteps,
timestep_cond=None,
@@ -1388,6 +1433,9 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
# 12. If using a target model, update its parameters via EMA.
if args.use_target_model:
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
progress_bar.update(1)
global_step += 1

@@ -1418,7 +1466,9 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
logger.info(f"Saved state to {save_path}")

if global_step % args.validation_steps == 0:
log_validation(vae, unet, args, accelerator, weight_dtype, global_step)
if args.use_target_model:
log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target")
log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "online")

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -1431,10 +1481,17 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet.save_pretrained(args.output_dir)
unet.save_pretrained(os.path.join(args.output_dir, "unet"))
lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)

if args.use_target_model:
target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
target_lora_state_dict = get_peft_model_state_dict(target_unet, adapter_name="default")
StableDiffusionPipeline.save_lora_weights(
os.path.join(args.output_dir, "unet_target_lora"), target_lora_state_dict
)

if args.push_to_hub:
upload_folder(
repo_id=repo_id,
examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -254,7 +254,7 @@ def train_dataloader(self):
return self._train_dataloader


def log_validation(vae, unet, args, accelerator, weight_dtype, step):
def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="lora"):
logger.info("Running validation... ")

unet = accelerator.unwrap_model(unet)
@@ -323,7 +323,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)

tracker.log({"validation": formatted_images})
tracker.log({f"validation/{name}": formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")

@@ -422,6 +422,20 @@ def ddim_step(self, pred_x0, pred_noise, timestep_index):
return x_prev


@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.

:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
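
As a quick sanity check of the EMA rule above (targ <- rate * targ + (1 - rate) * src), here is a toy example; it assumes only torch plus the update_ema definition given here:

```python
import torch

target = [torch.zeros(3)]
source = [torch.ones(3)]

# One EMA step with rate=0.95: 0.95 * 0.0 + 0.05 * 1.0 = 0.05 per element.
update_ema(target, source, rate=0.95)
print(target[0])  # tensor([0.0500, 0.0500, 0.0500])
```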


def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
@@ -657,6 +671,24 @@ def parse_args():
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
)
# ----Latent Consistency Distillation (LCD) Specific Arguments----
parser.add_argument(
"--use_target_model",
action="store_true",

Reviewer: Should this default to false so existing users are not surprised?

Contributor Author: The target model will be used only if the --use_target_model flag is specified (so existing script calls should work as before).

Reviewer: Thanks. My bad.

help=(
"Whether to use a target model in addition to the U-Net student model. Using a target model can help"
" improve training stability at the cost of more GPU memory usage."
),
)
parser.add_argument(
"--ema_decay",
type=float,
default=0.95,
required=False,
help=(
"The exponential moving average (EMA) rate or decay factor to be used for updating the target model"
" parameters (if using a target model).",
),
)
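
To make the review thread above concrete: argparse store_true flags default to False, so invocations that omit --use_target_model behave exactly as before. A minimal standalone sketch (plain argparse, independent of the training script):

```python
import argparse

# store_true flags are False unless the flag is passed explicitly.
parser = argparse.ArgumentParser()
parser.add_argument("--use_target_model", action="store_true")
parser.add_argument("--ema_decay", type=float, default=0.95)

print(parser.parse_args([]))
# Namespace(ema_decay=0.95, use_target_model=False)
print(parser.parse_args(["--use_target_model", "--ema_decay", "0.9"]))
# Namespace(ema_decay=0.9, use_target_model=True)
```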
parser.add_argument(
"--w_min",
type=float,
@@ -975,6 +1007,12 @@ def main(args):
)
unet.train()

if args.use_target_model:
target_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
target_unet.train()

# Check that all trainable models are in full precision
low_precision_error_string = (
" Please make sure to always have all model weights in full float32 precision when starting training - even if"
@@ -1014,6 +1052,10 @@ def main(args):
)
unet = get_peft_model(unet, lora_config)

if args.use_target_model:
target_unet = get_peft_model(target_unet, lora_config)
target_unet.requires_grad_(False)

# 9. Handle mixed precision and device placement
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
@@ -1035,6 +1077,8 @@ def main(args):
teacher_unet.to(accelerator.device)
if args.cast_teacher_unet:
teacher_unet.to(dtype=weight_dtype)
if args.use_target_model:
target_unet.to(accelerator.device)

# Also move the alpha and sigma noise schedules to accelerator.device.
alpha_schedule = alpha_schedule.to(accelerator.device)
@@ -1054,14 +1098,25 @@ def save_model_hook(models, weights, output_dir):
# save weights in peft format to be able to load them back
unet_.save_pretrained(os.path.join(output_dir, "unet_lora"))

if args.use_target_model:
target_lora_state_dict = get_peft_model_state_dict(target_unet, adapter_name="default")
StableDiffusionXLPipeline.save_lora_weights(
os.path.join(output_dir, "unet_target_lora"), target_lora_state_dict
)
# save weights in peft format to be able to load them back
target_unet.save_pretrained(os.path.join(output_dir, "unet_target_lora"))

for _, model in enumerate(models):
# make sure to pop weight so that corresponding model is not saved again
weights.pop()

def load_model_hook(models, input_dir):
# load the LoRA into the model
unet_ = accelerator.unwrap_model(unet)
unet_.load_adapter(input_dir, "default", is_trainable=True)
unet_.load_adapter(os.path.join(input_dir, "unet_lora"), "default", is_trainable=True)

if args.use_target_model:
target_unet.load_adapter(os.path.join(input_dir, "unet_target_lora"), "default", is_trainable=True)

for _ in range(len(models)):
# pop models so that they are not loaded again
@@ -1082,7 +1137,8 @@ def load_model_hook(models, input_dir):
)
unet.enable_xformers_memory_efficient_attention()
teacher_unet.enable_xformers_memory_efficient_attention()
# target_unet.enable_xformers_memory_efficient_attention()
if args.use_target_model:
target_unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

@@ -1408,10 +1464,13 @@ def compute_embeddings(
x_prev = solver.ddim_step(pred_x0, pred_noise, index)

# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
if args.use_target_model:
target_model = target_unet
else:
target_model = unet
with torch.no_grad():
with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
target_noise_pred = unet(
target_noise_pred = target_model(
x_prev.float(),
timesteps,
timestep_cond=None,
@@ -1446,6 +1505,9 @@ def compute_embeddings(

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
# 12. If using a target model, update its parameters via EMA.
if args.use_target_model:
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
progress_bar.update(1)
global_step += 1

@@ -1476,7 +1538,9 @@ def compute_embeddings(
logger.info(f"Saved state to {save_path}")

if global_step % args.validation_steps == 0:
log_validation(vae, unet, args, accelerator, weight_dtype, global_step)
if args.use_target_model:
log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target")
log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "online")

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -1489,10 +1553,17 @@ def compute_embeddings(
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet.save_pretrained(args.output_dir)
unet.save_pretrained(os.path.join(args.output_dir, "unet"))
lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)

if args.use_target_model:
target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
target_lora_state_dict = get_peft_model_state_dict(target_unet, adapter_name="default")
StableDiffusionXLPipeline.save_lora_weights(
os.path.join(args.output_dir, "unet_target_lora"), target_lora_state_dict
)

if args.push_to_hub:
upload_folder(
repo_id=repo_id,
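
For completeness, a hedged sketch of consuming what these scripts save: the final LoRA weights land in output_dir/unet_lora (and output_dir/unet_target_lora when --use_target_model is set), in the format written by save_lora_weights. Loading them for inference might look like the following; the base model ID and paths are assumptions, not part of this PR:

```python
import torch
from diffusers import LCMScheduler, StableDiffusionPipeline

# Hypothetical model ID and paths; adjust to your training run.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# LCM-distilled models are sampled with the LCM scheduler in a few steps.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Load the LoRA weights written by StableDiffusionPipeline.save_lora_weights.
pipe.load_lora_weights("output_dir/unet_lora")

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]
image.save("lcm_lora_sample.png")
```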