The training script repeats `prompt_embeds` once for every picture in the batch (via `repeat`). But in `collate_fn`:
```python
def collate_fn(examples, with_prior_preservation=False):
    pixel_values = [example["instance_images"] for example in examples]
    prompts = [example["instance_prompt"] for example in examples]
    original_sizes = [example["original_size"] for example in examples]
    crop_top_lefts = [example["crop_top_left"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        pixel_values += [example["class_images"] for example in examples]
        prompts += [example["class_prompt"] for example in examples]
        original_sizes += [example["original_size"] for example in examples]
        crop_top_lefts += [example["crop_top_left"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    batch = {
        "pixel_values": pixel_values,
        "prompts": prompts,
        "original_sizes": original_sizes,
        "crop_top_lefts": crop_top_lefts,
    }
    return batch
```
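A minimal sketch (toy examples, not the real dataset class) of the ordering this `collate_fn` logic produces when prior preservation is on — all instance entries first, then all class entries:

```python
# Toy batch of two examples; the prompt fields are illustrative placeholders.
examples = [
    {"instance_prompt": "photo of sks dog", "class_prompt": "photo of dog"},
    {"instance_prompt": "photo of sks dog", "class_prompt": "photo of dog"},
]

# Same pattern as collate_fn: instance prompts first, class prompts appended after.
prompts = [e["instance_prompt"] for e in examples]
prompts += [e["class_prompt"] for e in examples]

print(prompts)
# ['photo of sks dog', 'photo of sks dog', 'photo of dog', 'photo of dog']
```

So the batch order is `[instance, instance, ..., class, class]`, not interleaved.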
You can see that the class images are appended directly after the instance images in the batch. But when no `train_dataset.custom_instance_prompts` are provided, the prompt embeds are built like this:
```python
if not train_dataset.custom_instance_prompts:
    if not args.train_text_encoder:
        prompt_embeds = instance_prompt_hidden_states
        unet_add_text_embeds = instance_pooled_prompt_embeds
        if args.with_prior_preservation:
            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
            unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
    # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts)
    # we need to tokenize and encode the batch prompts on all training steps
    else:
        tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
        tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
        if args.with_prior_preservation:
            class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
            class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
            tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
            tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
```
They seem to come out as `[ins_token, cls_token]` or `[ins_embed, cls_embed]`.
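A minimal sketch (toy tensors with made-up shapes, standing in for the real hidden states) of what the `torch.cat(..., dim=0)` above produces:

```python
import torch

# Illustrative stand-ins: one instance embed (all zeros) and one class embed
# (all ones), each shaped [1, seq_len, dim].
ins_embed = torch.zeros(1, 4, 8)  # stands in for instance_prompt_hidden_states
cls_embed = torch.ones(1, 4, 8)   # stands in for class_prompt_hidden_states

# Concatenate along the batch dimension, as in the training script.
prompt_embeds = torch.cat([ins_embed, cls_embed], dim=0)

print(prompt_embeds.shape)  # torch.Size([2, 4, 8])
# Row 0 is the instance embed, row 1 is the class embed: [ins_embed, cls_embed]
```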
All of the snippets above are from `train_dreambooth_lora_sdxl.py`.
So back to the code:

```python
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
```

Is this wrong? Because if you use `repeat`, the embeds come out as `[ins_embed, cls_embed, ins_embed, cls_embed, ...]`, but I think they should be `[ins_embed, ins_embed, ..., cls_embed, cls_embed]` to match the batch order.
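A minimal sketch (toy tensors, not the script's real embeds) of the two orderings in question, comparing `repeat` with `repeat_interleave`:

```python
import torch

# A "prompt_embeds" of one instance embed (zeros) and one class embed (ones),
# each a [1, 2, 2] block, stacked along dim 0 — as in the with_prior_preservation path.
prompt_embeds = torch.cat([torch.zeros(1, 2, 2), torch.ones(1, 2, 2)], dim=0)

# .repeat tiles the whole tensor, interleaving instance and class:
tiled = prompt_embeds.repeat(2, 1, 1)
print(tiled[:, 0, 0].tolist())  # [0.0, 1.0, 0.0, 1.0]  -> [ins, cls, ins, cls]

# .repeat_interleave repeats each row in place, which matches the
# [instance..., class...] order that collate_fn produces:
interleaved = prompt_embeds.repeat_interleave(2, dim=0)
print(interleaved[:, 0, 0].tolist())  # [0.0, 0.0, 1.0, 1.0]  -> [ins, ins, cls, cls]
```

If the batch really is `[instance..., class...]`, `repeat_interleave` (or repeating instance and class embeds separately before concatenating) would give the alignment the question expects.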