Skip to content

Commit

Permalink
code commit
Browse files Browse the repository at this point in the history
  • Loading branch information
quattrinifabio committed Aug 8, 2024
1 parent fe98ac5 commit 21c880d
Show file tree
Hide file tree
Showing 16 changed files with 15,009 additions and 3 deletions.
130 changes: 130 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# VS Code
.vscode/


images/
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
# Alfie

## Setup

```bash
conda create --name alfie python==3.11.7
conda activate alfie
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
pip install -r requirements.txt
pip install git+https://github.com/openai/CLIP.git
pip install git+https://github.com/SunzeY/AlphaCLIP.git

```

## Example usage:

```bash
python generate_prompt.py --setting centering-rgba-alfie --fg_prompt 'A photo of a cat with a hat'

```



Code inspired by [DAAM](https://github.com/castorini/daam) and [DAAMI2I](https://github.com/RishiDarkDevil/daam-i2i)
Code inspired by [DAAM](https://github.com/castorini/daam) and [DAAMI2I](https://github.com/RishiDarkDevil/daam-i2i)
131 changes: 131 additions & 0 deletions generate_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from pathlib import Path
from settings import parse_setting
import json

from sam_aclip_pixart_sigma.generate import get_pipe, base_arg_parser, parse_bool_args

from transformers import VitMatteImageProcessor, VitMatteForImageMatting

import logging
from accelerate import PartialState
from accelerate.logging import get_logger
from accelerate.utils import set_seed

from sam_aclip_pixart_sigma.grabcut import grabcut, save_rgba

import torch
from sam_aclip_pixart_sigma.trimap import compute_trimap
from sam_aclip_pixart_sigma.utils import normalize_masks

# Enable TF32 for fp32 matmuls — faster on Ampere+ GPUs at a small,
# usually negligible, precision cost for diffusion inference.
torch.backends.cuda.matmul.allow_tf32 = True

# Root logging config shared by this script and the libraries it drives.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
# Accelerate-aware logger: by default it only emits on the main process
# when running under a distributed launcher.
logger = get_logger(__name__)


def main():
    """Generate images for a foreground prompt and save RGBA cutouts.

    Pipeline: parse CLI args merged with a named setting, run the PixArt
    pipe once per seed, save the RGB image, then cut out the foreground
    using either GrabCut (attention-guided trimap thresholds) or
    ViT-Matte (alpha matting from a computed trimap).

    Side effects: creates ``args.save_folder / 'prompts'``, writes a JSON
    config dump, one RGB PNG and one RGBA PNG per seed.

    Raises:
        ValueError: if ``args.cutout_model`` is neither ``'grabcut'``
            nor ``'vit-matte'``.
    """
    parser = base_arg_parser()
    parser.add_argument("--setting_name", type=str, default='centering-rgba-alfie')
    parser.add_argument("--fg_prompt", type=str, required=True)
    args = parser.parse_args()
    # Overlay the named setting's values on top of the parsed CLI args.
    settings_dict = parse_setting(args.setting_name)
    vars(args).update(settings_dict)
    args = parse_bool_args(args)

    distributed_state = PartialState()
    args.device = distributed_state.device

    args.save_folder = args.save_folder / 'prompts'
    args.save_folder.mkdir(parents=True, exist_ok=True)

    pipe = get_pipe(
        image_size=args.image_size,
        scheduler=args.scheduler,
        device=args.device)

    suffix = ' on a white background'
    prompt_complete = ["A white background", args.fg_prompt]
    # Whitespace-normalized foreground prompt, used for file naming.
    prompt_full = ' '.join(prompt_complete[1].split())
    negative_prompt = ["Blurry, shadow, low-resolution, low-quality"] if args.use_neg_prompt else None
    # With centering the pipe receives [background_prompt, fg_prompt];
    # otherwise just the foreground prompt string.
    prompt = prompt_complete if args.centering else prompt_complete[1]
    if args.use_suffix:
        # BUGFIX: the original `prompt += suffix` extended the *list* with
        # the suffix's individual characters when centering was enabled.
        # Append the suffix to the foreground prompt string in both cases.
        if isinstance(prompt, list):
            prompt = prompt[:-1] + [prompt[-1] + suffix]
        else:
            prompt = prompt + suffix

    if args.cutout_model == 'vit-matte':
        vit_matte_processor = VitMatteImageProcessor.from_pretrained(args.vit_matte_key)
        vit_matte_model = VitMatteForImageMatting.from_pretrained(args.vit_matte_key)
        vit_matte_model = vit_matte_model.eval()

    base_name = '_'.join([
        prompt_full,
        'centering' if args.centering else '',
        'sz_256' if args.image_size == 256 else 'sz_512'
    ])

    # Dump the effective configuration, dropping per-run / non-serializable
    # entries (device is a torch.device, save_folder a Path).
    config = vars(args).copy()
    del config['device']
    del config['save_folder']
    del config['seed']
    del config['num_images']
    with open(args.save_folder / f'{base_name}.json', 'w') as f:
        json.dump(config, f, indent=4)

    for seed in range(args.seed, args.seed + args.num_images):
        set_seed(seed)
        # BUGFIX: use args.device instead of a hardcoded "cuda" so the
        # script also runs on CPU-only hosts, consistent with get_pipe().
        generator = torch.Generator(device=args.device).manual_seed(seed)
        name = f'{base_name}_seed_{seed}'

        images, heatmaps = pipe(
            prompt=prompt, negative_prompt=negative_prompt, nouns_to_exclude=args.nouns_to_exclude,
            keep_cross_attention_maps=True, return_dict=False, num_inference_steps=args.steps,
            centering=args.centering, generator=generator)

        image = images[0]
        rgb_image_filename = Path(args.save_folder / f"{name}.png")
        if not rgb_image_filename.exists():
            rgb_image_filename.parent.mkdir(parents=True, exist_ok=True)
            image.save(rgb_image_filename)

        torch.cuda.empty_cache()

        if args.cutout_model == 'grabcut':
            alpha_mask = grabcut(
                image=image, attention_maps=list(heatmaps['cross_heatmaps_fg_nouns'].values()), image_size=args.image_size,
                sure_fg_threshold=args.sure_fg_threshold, maybe_fg_threshold=args.maybe_fg_threshold,
                maybe_bg_threshold=args.maybe_bg_threshold)

            alfie_rgba_image_filename = Path(args.save_folder / f"{name}-rgba-alfie.png")
            alfie_rgba_image_filename.parent.mkdir(parents=True, exist_ok=True)
            # Inside the GrabCut foreground, modulate alpha by the combined
            # self/cross attention heatmaps; everything else is transparent.
            alpha_mask_alfie = torch.tensor(alpha_mask)
            alpha_mask_alfie = torch.where(alpha_mask_alfie == 1, normalize_masks(heatmaps['ff_heatmap'] + 1 * heatmaps['cross_heatmap_fg']), 0.)
            save_rgba(image, alpha_mask_alfie, alfie_rgba_image_filename)

        elif args.cutout_model == 'vit-matte':
            trimap = compute_trimap(attention_maps=[list(heatmaps['cross_heatmaps_fg_nouns'].values())],
                                    image_size=args.image_size,
                                    sure_fg_threshold=args.sure_fg_threshold,
                                    maybe_bg_threshold=args.maybe_bg_threshold)

            vit_matte_inputs = vit_matte_processor(images=image, trimaps=trimap, return_tensors="pt").to(args.device)
            vit_matte_model = vit_matte_model.to(args.device)
            with torch.no_grad():
                alpha_mask = vit_matte_model(**vit_matte_inputs).alphas[0, 0]
            # NOTE(review): alphas appear to be inverted here (1 - alpha);
            # presumably the model outputs background probability — confirm.
            alpha_mask = 1 - alpha_mask.cpu().numpy()
            save_rgba(image, alpha_mask, args.save_folder / f"{name}-rgba-vit_matte.png")
        else:
            raise ValueError(f'Invalid cutout model: {args.cutout_model}')

        # Free the attention maps before the next seed to keep GPU memory flat.
        del heatmaps
        torch.cuda.empty_cache()

    logger.info("***** Done *****")

# Script entry point: only run when executed directly, not when imported.
if __name__ == '__main__':
    main()
15 changes: 15 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
einops~=0.7.0
nltk~=3.8.1
accelerate~=0.33.0
diffusers~=0.29.2
transformers~=4.43.3
tqdm~=4.66.2
pillow~=10.2.0
SentencePiece~=0.2.0
ftfy~=6.2.0
beautifulsoup4~=4.12.3
opencv-contrib-python~=4.9.0.80
scikit-image~=0.23.1
matplotlib~=3.8.4
loralib~=0.1.2
spacy~=3.7.5
1 change: 1 addition & 0 deletions sam_aclip_pixart_sigma/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .transformer_2d import Transformer2DModel
Loading

0 comments on commit 21c880d

Please sign in to comment.