Skip to content

Commit 93dbe5d

Browse files
Merge pull request #1 from rapidrabbit76/develop
fix: CORS, Update: pipeline
2 parents 4e2781a + 1de9edf commit 93dbe5d

File tree

15 files changed

+105
-36
lines changed

15 files changed

+105
-36
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
---
2+
name: Bug Report Template
3+
about: 버그 리포트 템플릿
4+
title: ""
5+
assignees: "yslee"
6+
---
7+
8+
# System info
9+
10+
# Describe of bug
11+
12+
# Code example
13+
14+
# Log or Screenshot
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
name: Feature Request Template
3+
about: 기능 및 추가 요청을 위한 템플릿
4+
title: ""
5+
assignees: "yslee"
6+
---
7+
8+
# Description
9+
10+
# TODO
11+
12+
- [ ] todo
13+
- [ ] todo
14+
15+
# ETC

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Summary
2+
3+
content
4+
5+
# Work
6+
7+
content
8+
9+
# Related issues [optional]
10+
11+
write related issues

api/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .stable_diffusion import StableDiffusionRouter
1+
from .stable_diffusion import router as StableDiffusionRouter
2+
from .home import router as HomeRouter

api/home/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .router import router

api/home/router.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from datetime import datetime
2+
from fastapi import Response
3+
from fastapi_restful.cbv import cbv
4+
from fastapi_restful.inferring_router import InferringRouter
5+
6+
from core.settings import get_settings
7+
8+
9+
router = InferringRouter()
10+
env = get_settings()
11+
12+
13+
@cbv(router)
14+
class Home:
15+
@router.get("/")
16+
async def index(self):
17+
"""ELB check"""
18+
current_time = datetime.utcnow()
19+
msg = f"Notification API (UTC: {current_time.strftime('%Y.%m.%d %H:%M:%S')})"
20+
return Response(msg)

api/stable_diffusion/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .router import router as StableDiffusionRouter
1+
from .router import router

app/server.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
from datetime import datetime
21
from fastapi.staticfiles import StaticFiles
2+
import typing as T
33

44
from fastapi import FastAPI, Response
55
from fastapi.middleware.cors import CORSMiddleware
6+
from fastapi.middleware import Middleware
67

78
import api
8-
from core.settings import get_settings
9-
10-
env = get_settings()
9+
from core.settings import env
1110

1211

1312
def init_router(app: FastAPI):
@@ -17,28 +16,30 @@ def init_router(app: FastAPI):
1716
name="result image",
1817
)
1918
app.include_router(api.StableDiffusionRouter)
19+
app.include_router(api.HomeRouter)
2020
app.router.redirect_slashes = False
2121

2222

2323
def create_app() -> FastAPI:
24-
app = FastAPI(redoc_url=None)
25-
26-
init_cors(app)
27-
init_middleware(app)
24+
app = FastAPI(
25+
redoc_url=None,
26+
middleware=init_middleware(),
27+
)
2828
init_router(app)
29-
init_settings(app)
3029
return app
3130

3231

33-
def init_cors(app: FastAPI):
34-
app.add_middleware(
35-
CORSMiddleware,
36-
allow_origins=env.CORS_ALLOW_ORIGINS,
37-
)
38-
39-
40-
def init_middleware(app: FastAPI):
41-
pass
32+
def init_middleware() -> T.List[Middleware]:
33+
middleware = [
34+
Middleware(
35+
CORSMiddleware,
36+
allow_origins=env.CORS_ALLOW_ORIGINS,
37+
allow_credentials=env.CORS_CREDENTIALS,
38+
allow_methods=env.CORS_ALLOW_METHODS,
39+
allow_headers=env.CORS_ALLOW_HEADERS,
40+
),
41+
]
42+
return middleware
4243

4344

4445
def init_settings(app: FastAPI):
@@ -50,12 +51,6 @@ def startup_event():
5051
def shutdown_event():
5152
pass
5253

53-
@app.get("/")
54-
async def index():
55-
"""ELB check"""
56-
current_time = datetime.utcnow()
57-
msg = f"Notification API (UTC: {current_time.strftime('%Y.%m.%d %H:%M:%S')})"
58-
return Response(msg)
59-
6054

6155
app = create_app()
56+
init_settings(app)

app/stable_diffusion/manager/manager.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from functools import lru_cache
22
import typing as T
33
import torch
4+
import sys
5+
from random import randint
46
from service_streamer import ThreadedStreamer
57
from app.stable_diffusion.model import (
68
build_text2image_pipeline,
@@ -13,7 +15,6 @@
1315
Image2ImageTask,
1416
)
1517
from core.settings import get_settings
16-
from core.decorator.singleton import singleton
1718

1819
env = get_settings()
1920

@@ -24,7 +25,6 @@
2425
]
2526

2627

27-
@singleton
2828
class StableDiffusionManager:
2929
def __init__(self):
3030
self.text2image = build_text2image_pipeline()
@@ -44,9 +44,11 @@ def predict(
4444
images = self.predict_inpaint(task)
4545
return [images]
4646

47-
def _get_generator(self, task, device):
47+
def _get_generator(self, task: _StableDiffusionTask, device: str):
4848
generator = torch.Generator(device=device)
49-
generator.manual_seed(task.seed)
49+
seed = task.seed
50+
seed = seed if seed else randint(1, sys.maxsize)
51+
generator.manual_seed(seed)
5052
return generator
5153

5254
def predict_text2image(self, task: Text2ImageTask):

app/stable_diffusion/manager/schema.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class Text2ImageTask(BaseModel):
99
guidance_scale: float = Field(..., ge=0.0)
1010
height: int
1111
width: int
12-
seed: int = Field(..., gt=0)
12+
seed: T.Optional[int] = Field(..., gt=0)
1313

1414
@validator("height", "width")
1515
def size_constraint(cls, size):
@@ -25,7 +25,7 @@ class Image2ImageTask(BaseModel):
2525
strength: float = Field(..., ge=0.0, le=1.0)
2626
num_inference_steps: int = Field(..., gt=0)
2727
guidance_scale: float = Field(..., ge=0.0)
28-
seed: int = Field(..., gt=0)
28+
seed: T.Optional[int] = Field(..., gt=0)
2929

3030

3131
class InpaintTask(BaseModel):
@@ -35,4 +35,4 @@ class InpaintTask(BaseModel):
3535
strength: float = Field(..., ge=0.0, le=1.0)
3636
num_inference_steps: int = Field(..., gt=0)
3737
guidance_scale: float = Field(..., ge=0.0)
38-
seed: int = Field(..., gt=0)
38+
seed: T.Optional[int] = Field(..., gt=0)

0 commit comments

Comments
 (0)