forked from lucataco/serverless-template-dreambooth-training
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.sh
28 lines (27 loc) · 895 Bytes
/
train.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#!/usr/bin/env bash
#
# train.sh - launch DreamBooth fine-tuning of Stable Diffusion via
# `accelerate launch train_dreambooth.py`.
#
# Usage:
#   ./train.sh [extra train_dreambooth.py flags...]
#
# Extra arguments are appended after the defaults below. If
# train_dreambooth.py parses its flags with argparse (where the last
# occurrence of an option wins) they can override individual defaults,
# e.g. `./train.sh --seed=42` - NOTE(review): confirm against the script.
#
# Expects, relative to the current working directory:
#   train_dreambooth.py  - the training entry point
#   concepts_list.json   - passed via --concepts_list
# Output goes to stable_diffusion_weights/.
set -euo pipefail

# Kept exported (as before) in case the launched process reads them.
export NUM_STEPS=1200
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="stable_diffusion_weights/"

# Prepend the WSL library directory (presumably so the CUDA driver library
# is found under WSL2 - confirm). The ':' separator is only added when
# LD_LIBRARY_PATH is already non-empty: a trailing empty entry would make
# the glibc loader also search the current working directory.
export LD_LIBRARY_PATH="/usr/lib/wsl/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"

# Default training flags, one per line. An array keeps each value a single
# word (no word-splitting/globbing) without a fragile '\' continuation chain.
train_args=(
  --pretrained_model_name_or_path="$MODEL_NAME"
  --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse"
  --output_dir="$OUTPUT_DIR"
  --with_prior_preservation
  --prior_loss_weight=1.0
  --seed=3434554
  --resolution=512
  --train_batch_size=1
  --train_text_encoder
  --mixed_precision="fp16"
  --use_8bit_adam
  --gradient_accumulation_steps=1
  --learning_rate=1e-6
  --lr_scheduler="constant"
  --lr_warmup_steps=0
  --num_class_images=50
  --sample_batch_size=1
  --max_train_steps="$NUM_STEPS"
  # Same value as --max_train_steps; presumably this yields a single save at
  # the end of training - confirm in train_dreambooth.py.
  --save_interval="$NUM_STEPS"
  # NOTE(review): "sks" looks like the instance identifier token; it likely
  # has to match the instance prompt in concepts_list.json - verify.
  --save_sample_prompt="photo of sks person"
  --concepts_list="concepts_list.json"
  --pad_tokens
)

# The script's exit status is that of the training run.
accelerate launch train_dreambooth.py "${train_args[@]}" "$@"