Skip to content

Commit 4490038

Browse files
committed
feat: condition text info token size up #4
1 parent 9a5a413 commit 4490038

File tree

10 files changed

+85
-211
lines changed

10 files changed

+85
-211
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,7 @@ temp/
2525

2626
# DEV
2727
debug.py
28-
*.ipynb
28+
*.ipynb
29+
30+
static/
31+
lpw_stable_diffusion/

README.md

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,18 @@ outputs:
8787
>> ./core/settings/settings.py
8888
```
8989

90-
| Name | Default | Desc |
91-
| ------------------ | ----------------------------- | ------------------------------------------------- |
92-
| MODEL_ID | CompVis/stable-diffusion-v1-4 | tagger embedding model part |
93-
| CUDA_DEVICE | "cpu" | target cuda device |
94-
| CUDA_DEVICES | [0] | visible cuda device |
95-
| MB_BATCH_SIZE | 1 | Micro Batch: MAX Batch size |
96-
| MB_TIMEOUT | 120 | Micro Batch: timeout sec |
97-
| HUGGINGFACE_TOKEN | None | huggingface access token |
98-
| IMAGESERVER_URL | None | result image base url |
99-
| SAVE_DIR | static | result image save dir |
100-
| CORS_ALLOW_ORIGINS | [*] | cross origin resource sharing setting for FastAPI |
90+
| Name | Default | Desc |
91+
| ------------------------ | ----------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
92+
| MODEL_ID | CompVis/stable-diffusion-v1-4 | huggingface repo id or model path |
93+
| ENABLE_ATTENTION_SLICING | True | [Enable sliced attention computation.](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline.enable_attention_slicing) |
94+
| CUDA_DEVICE | "cuda" | target cuda device |
95+
| CUDA_DEVICES | [0] | visible cuda device |
96+
| MB_BATCH_SIZE | 1 | Micro Batch: MAX Batch size |
97+
| MB_TIMEOUT | 120 | Micro Batch: timeout sec |
98+
| HUGGINGFACE_TOKEN | None | huggingface access token |
99+
| IMAGESERVER_URL | None | result image base url |
100+
| SAVE_DIR | static | result image save dir |
101+
| CORS_ALLOW_ORIGINS | [*] | cross origin resource sharing setting for FastAPI |
101102

102103
# RUN from code (API)
103104

@@ -117,6 +118,9 @@ python huggingface_model_download.py
117118
## 3. update settings.py in ./core/settings/settings.py
118119
```python
119120
# example
121+
class ModelSetting(BaseSettings):
122+
MODEL_ID: str = "CompVis/stable-diffusion-v1-4" # huggingface repo id or local model path
123+
ENABLE_ATTENTION_SLICING: bool = True
120124
...
121125
class Settings(
122126
...
@@ -127,13 +131,9 @@ class Settings(
127131
...
128132
```
129133

130-
## 4. RUN API by uvicorn
134+
## 4. RUN API from code
131135
```bash
132-
cd /REPO/ROOT/DIR/PATH
133-
python3 -m uvicorn app.server:app \
134-
--host 0.0.0.0 \
135-
--port 3000 \
136-
--workers 1
136+
bash docker/api/start.sh
137137
```
138138

139139
# RUN from code (frontend)
@@ -161,27 +161,19 @@ streamlit run inpaint.py
161161
docker-compose build
162162
```
163163

164-
## 2. download and cache huggingface model
165-
```bash
166-
python huggingface_model_download.py
167-
# check stable-diffusion model in huggingface cache dir
168-
[[ -d ~/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4 ]] && echo "exist"
169-
>> exist
170-
```
171-
172164
## 3. update docker-compose.yaml file in repo root
173165
```yaml
174166
version: "3.7"
175-
176167
services:
177168
api:
178169
...
179170
volumes:
180171
# mount the host model dir into the container (needed only when MODEL_ID is a local model path)
181-
- /home/{USER NAME}/.cache/huggingface:/root/.cache/huggingface
172+
- /model:/model # if you load a pretrained model
182173
- ...
183174
environment:
184175
...
176+
MODEL_ID: "CompVis/stable-diffusion-v1-4"
185177
HUGGINGFACE_TOKEN: {YOUR HUGGINGFACE ACCESS TOKEN}
186178
...
187179

app/stable_diffusion/manager/manager.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,15 @@
66
import sys
77
from random import randint
88
from service_streamer import ThreadedStreamer
9-
from app.stable_diffusion.model import (
10-
build_text2image_pipeline,
11-
build_image2image_pipeline,
12-
build_inpaint_pipeline,
13-
)
9+
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
10+
1411
from app.stable_diffusion.manager.schema import (
1512
InpaintTask,
1613
Text2ImageTask,
1714
Image2ImageTask,
1815
)
1916
from core.settings import get_settings
17+
from functools import lru_cache
2018

2119
env = get_settings()
2220

@@ -27,30 +25,64 @@
2725
]
2826

2927

28+
@lru_cache()
29+
def build_pipeline(repo: str, device: str, enable_attention_slicing: bool):
30+
pipe = DiffusionPipeline.from_pretrained(
31+
repo,
32+
torch_dtype=torch.float16,
33+
revision="fp16",
34+
custom_pipeline="lpw_stable_diffusion",
35+
)
36+
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
37+
pipe.scheduler.config
38+
)
39+
pipe.safety_checker = lambda images, clip_input: (images, False)
40+
41+
if enable_attention_slicing:
42+
pipe.enable_attention_slicing()
43+
44+
pipe = pipe.to(device)
45+
return pipe
46+
47+
48+
build_pipeline(
49+
repo=env.MODEL_ID,
50+
device=env.CUDA_DEVICE,
51+
enable_attention_slicing=env.ENABLE_ATTENTION_SLICING,
52+
)
53+
54+
3055
class StableDiffusionManager:
3156
def __init__(self):
32-
self.text2image = build_text2image_pipeline()
33-
self.image2image = build_image2image_pipeline()
34-
self.inpaint = build_inpaint_pipeline()
57+
self.pipe = build_pipeline(
58+
repo=env.MODEL_ID,
59+
device=env.CUDA_DEVICE,
60+
enable_attention_slicing=env.ENABLE_ATTENTION_SLICING,
61+
)
3562

63+
@torch.inference_mode()
3664
def predict(
3765
self,
3866
batch: T.List[_StableDiffusionTask],
3967
):
4068
task = batch[0]
41-
pipeline = self.text2image
69+
pipeline = self.pipe
4270
if isinstance(task, Text2ImageTask):
43-
pipeline = self.text2image
71+
pipeline = self.pipe.text2img
4472
elif isinstance(task, Image2ImageTask):
45-
pipeline = self.image2image
73+
pipeline = self.pipe.img2img
4674
elif isinstance(task, InpaintTask):
47-
pipeline = self.inpaint
75+
pipeline = self.pipe.inpaint
76+
else:
77+
raise NotImplementedError
4878

4979
device = env.CUDA_DEVICE
5080

5181
generator = self._get_generator(task, device)
5282
with torch.autocast("cuda" if device != "cpu" else "cpu"):
53-
images = pipeline(**task.dict(), generator=generator)
83+
task = task.dict()
84+
del task["seed"]
85+
images = pipeline(**task, generator=generator).images
5486
if device != "cpu":
5587
torch.cuda.empty_cache()
5688

app/stable_diffusion/manager/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def size_constraint(cls, size):
2323
class Image2ImageTask(BaseModel):
2424
prompt: T.Union[str, T.List[str]] = Field(...)
2525
negative_prompt: T.Union[str, T.List[str]] = Field(...)
26-
init_image: T.Any
26+
image: T.Any
2727
strength: float = Field(..., ge=0.0, le=1.0)
2828
num_inference_steps: int = Field(..., gt=0)
2929
guidance_scale: float = Field(..., ge=0.0)
@@ -33,7 +33,7 @@ class Image2ImageTask(BaseModel):
3333
class InpaintTask(BaseModel):
3434
prompt: T.Union[str, T.List[str]] = Field(...)
3535
negative_prompt: T.Union[str, T.List[str]] = Field(...)
36-
init_image: T.Any
36+
image: T.Any
3737
mask_image: T.Any
3838
strength: float = Field(..., ge=0.0, le=1.0)
3939
num_inference_steps: int = Field(..., gt=0)

app/stable_diffusion/model.py

Lines changed: 0 additions & 149 deletions
This file was deleted.

app/stable_diffusion/service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def image2image(
9292
Image2ImageTask(
9393
prompt=prompt,
9494
negative_prompt=[negative_prompt] * len(prompt),
95-
init_image=init_image,
95+
image=init_image,
9696
strength=strength,
9797
num_inference_steps=num_inference_steps,
9898
guidance_scale=guidance_scale,

core/settings/settings.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
class ModelSetting(BaseSettings):
88
MODEL_ID: str = "CompVis/stable-diffusion-v1-4"
9+
ENABLE_ATTENTION_SLICING: bool = True
910

1011

1112
class DeviceSettings(BaseSettings):
@@ -23,8 +24,8 @@ class Settings(
2324
DeviceSettings,
2425
MicroBatchSettings,
2526
):
26-
HUGGINGFACE_TOKEN: str
27-
IMAGESERVER_URL: str = 'http://localhost:3000/images'
27+
HUGGINGFACE_TOKEN: str = "HUGGINGFACE_TOKEN"
28+
IMAGESERVER_URL: str = "http://localhost:3000/images"
2829
SAVE_DIR: str = "static"
2930

3031
CORS_ALLOW_ORIGINS: T.List[str] = ["*"]

0 commit comments

Comments
 (0)