From 102d6eac8bcbe834111fd488f1bd77f2506b73fe Mon Sep 17 00:00:00 2001 From: quattrinifabio Date: Thu, 8 Aug 2024 23:28:07 +0200 Subject: [PATCH] updates --- README.md | 1 + alfie/generate.py | 1 + generate_prompt.py | 4 +++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5a02c3f..4f057ff 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ pip install -r requirements.txt ```python python generate_prompt.py --setting centering-rgba-alfie --fg_prompt 'A photo of a cat with a hat' +python generate_prompt.py --setting centering-rgba-alfie --fg_prompt 'A large, colorful tree made of money, with lots of yellow and white coins hanging from its branches' ``` diff --git a/alfie/generate.py b/alfie/generate.py index 2dc693c..48382f4 100644 --- a/alfie/generate.py +++ b/alfie/generate.py @@ -94,6 +94,7 @@ def get_pipe(image_size, scheduler, device): def base_arg_parser(): parser = argparse.ArgumentParser() parser.add_argument("--centering", type=str, default='True') + parser.add_argument("--k", type=int, default=1) parser.add_argument("--resample", type=str, default='True') parser.add_argument("--scheduler", type=str, default='euler', choices=['euler', 'euler_ancestral']) parser.add_argument("--use_neg_prompt", type=str, default='True') diff --git a/generate_prompt.py b/generate_prompt.py index 792a400..17e883b 100644 --- a/generate_prompt.py +++ b/generate_prompt.py @@ -101,7 +101,9 @@ def main(): alfie_rgba_image_filename = Path(args.save_folder / f"{name}-rgba-alfie.png") alfie_rgba_image_filename.parent.mkdir(parents=True, exist_ok=True) alpha_mask_alfie = torch.tensor(alpha_mask) - alpha_mask_alfie = torch.where(alpha_mask_alfie == 1, normalize_masks(heatmaps['ff_heatmap'] + 1 * heatmaps['cross_heatmap_fg']), 0.) + alfa_hat = normalize_masks(heatmaps['ff_heatmap'] + 1 * heatmaps['cross_heatmap_fg']) + alfa_hat = (alfa_hat + args.k * alfa_hat).clip(0, 1) + alpha_mask_alfie = torch.where(alpha_mask_alfie == 1, alfa_hat, 0.) save_rgba(image, alpha_mask_alfie, alfie_rgba_image_filename) elif args.cutout_model == 'vit-matte':