Commit 79e3e17

hlky and Narsil authored
latent-to-image (#469)
* latent-to-image
* needs_upcasting, unscale/denormalize
* inputs
* Update .gitignore
* use base64
* make

---------

Co-authored-by: Nicolas Patry <[email protected]>
1 parent 329d9bb commit 79e3e17

File tree

18 files changed: +915 −1 lines changed


api_inference_community/validation.py

Lines changed: 32 additions & 1 deletion
@@ -196,8 +196,12 @@ def check_inputs(inputs, tag):
 IMAGE_OUTPUTS = {
     "image-to-image",
     "text-to-image",
+    "latent-to-image",
 }
 
+TENSOR_INPUTS = {
+    "latent-to-image",
+}
 
 TEXT_INPUTS = {
     "conversational",
@@ -218,7 +222,7 @@ def check_inputs(inputs, tag):
     "zero-shot-classification",
 }
 
-KNOWN_TASKS = AUDIO_INPUTS.union(IMAGE_INPUTS).union(TEXT_INPUTS)
+KNOWN_TASKS = AUDIO_INPUTS.union(IMAGE_INPUTS).union(TEXT_INPUTS).union(TENSOR_INPUTS)
 
 AUDIO = [
     "flac",
@@ -266,6 +270,8 @@ def normalize_payload(
         return normalize_payload_image(bpayload)
     elif task in TEXT_INPUTS:
         return normalize_payload_nlp(bpayload, task)
+    elif task in TENSOR_INPUTS:
+        return normalize_payload_tensor(bpayload)
     else:
         raise EnvironmentError(
             f"The task `{task}` is not recognized by api-inference-community"
@@ -407,3 +413,28 @@ def normalize_payload_nlp(bpayload: bytes, task: str) -> Tuple[Any, Dict]:
     check_params(parameters, task)
     check_inputs(inputs, task)
     return inputs, parameters
+
+
+def normalize_payload_tensor(bpayload: bytes) -> Tuple[Any, Dict]:
+    import torch
+
+    data = json.loads(bpayload)
+    tensor = data["inputs"]
+    tensor = b64decode(tensor.encode("utf-8"))
+    parameters = data.get("parameters", {})
+    if "shape" not in parameters:
+        raise ValueError("Expected `shape` in parameters.")
+    if "dtype" not in parameters:
+        raise ValueError("Expected `dtype` in parameters.")
+
+    DTYPE_MAP = {
+        "float16": torch.float16,
+        "float32": torch.float32,
+        "bfloat16": torch.bfloat16,
+    }
+
+    shape = parameters.pop("shape")
+    dtype = DTYPE_MAP.get(parameters.pop("dtype"))
+    tensor = torch.frombuffer(bytearray(tensor), dtype=dtype).reshape(shape)
+
+    return tensor, parameters
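To make the wire format concrete, here is a minimal client-side sketch that builds a payload `normalize_payload_tensor` can decode. The 1x4x64x64 float16 latent shape and the local round-trip check are illustrative assumptions, not part of the commit:

    import base64
    import json

    import torch

    # Payload layout expected by the server: base64-encoded raw tensor bytes
    # under "inputs", plus "shape" and "dtype" under "parameters".
    latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)
    payload = {
        "inputs": base64.b64encode(latents.numpy().tobytes()).decode("utf-8"),
        "parameters": {"shape": list(latents.shape), "dtype": "float16"},
    }
    bpayload = json.dumps(payload).encode("utf-8")

    # Round-trip through the same decoding steps as the diff above.
    data = json.loads(bpayload)
    raw = base64.b64decode(data["inputs"].encode("utf-8"))
    decoded = torch.frombuffer(bytearray(raw), dtype=torch.float16).reshape(
        data["parameters"]["shape"]
    )
    assert torch.equal(decoded, latents)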
docker_images/latent-to-image/Dockerfile

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04
+LABEL maintainer="Nicolas Patry <[email protected]>"
+
+# Add any system dependency here
+# RUN apt-get update -y && apt-get install libXXX -y
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install prerequisites
+RUN apt-get update && \
+    apt-get install -y build-essential libssl-dev zlib1g-dev libbz2-dev \
+    libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \
+    xz-utils tk-dev libffi-dev liblzma-dev python3-openssl git
+
+# Install pyenv
+RUN curl https://pyenv.run | bash
+
+# Set environment variables for pyenv
+ENV PYENV_ROOT=/root/.pyenv
+ENV PATH=$PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH
+
+# Install your desired Python version
+ARG PYTHON_VERSION=3.9.1
+RUN pyenv install $PYTHON_VERSION && \
+    pyenv global $PYTHON_VERSION && \
+    pyenv rehash
+
+RUN pip install --no-cache-dir --upgrade pip setuptools wheel && \
+    pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+
+WORKDIR /app
+COPY ./requirements.txt /app
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Most DL models are quite large in terms of memory, using workers is a HUGE
+# slowdown because of the fork and GIL with python.
+# Using multiple pods seems like a better default strategy.
+# Feel free to override if it does not make sense for your library.
+ARG max_workers=1
+ENV MAX_WORKERS=$max_workers
+ENV HUGGINGFACE_HUB_CACHE=/data
+ENV DIFFUSERS_CACHE=/data
+
+# Necessary in GPU docker environments.
+# The TIMEOUT env variable is used by nvcr.io/nvidia/pytorch:xx for another purpose,
+# making the TIMEOUT variable expected by uvicorn impossible to use correctly.
+# We're renaming it to UVICORN_TIMEOUT.
+# UVICORN_TIMEOUT is a useful variable for very large models that take more
+# than 30s (the default) to load in memory.
+# If UVICORN_TIMEOUT is too low, uvicorn will simply never load, as it will
+# kill workers all the time before they finish.
+COPY --from=tiangolo/uvicorn-gunicorn:python3.8 /app/ /app
+COPY --from=tiangolo/uvicorn-gunicorn:python3.8 /start.sh /
+COPY --from=tiangolo/uvicorn-gunicorn:python3.8 /gunicorn_conf.py /
+COPY app/ /app/app
+
+COPY ./prestart.sh /app/
+
+RUN sed -i 's/TIMEOUT/UVICORN_TIMEOUT/g' /gunicorn_conf.py
+
+CMD ["/start.sh"]
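For context on the sed rewrite: tiangolo's gunicorn_conf.py resolves its worker timeout from an environment variable, and after the rename above that lookup is keyed on UVICORN_TIMEOUT instead of TIMEOUT, sidestepping the clash with NVIDIA base images. A minimal sketch of the resulting behavior, assuming a 120s upstream default (not pinned by this Dockerfile):

    import os

    # After `sed -i 's/TIMEOUT/UVICORN_TIMEOUT/g'` the config reads
    # UVICORN_TIMEOUT, so NVIDIA's own TIMEOUT variable no longer interferes.
    # The "120" fallback is an assumption for illustration.
    timeout = int(os.getenv("UVICORN_TIMEOUT", "120"))

Deployment-side, exporting e.g. UVICORN_TIMEOUT=300 gives a large model more than the default window to load before gunicorn kills the worker.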

docker_images/latent-to-image/app/__init__.py

Whitespace-only changes.
docker_images/latent-to-image/app/healthchecks.py

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
+"""
+This file allows users to spawn some side service helping with giving a better view on the main ASGI app status.
+The issue with the status route of the main application is that it gets unresponsive as soon as all workers get busy.
+Thus, you cannot really use the said route as a healthcheck to decide whether your app is healthy or not.
+Instead this module allows you to distinguish between a dead service (not able to even tcp connect to app port)
+and a busy one (able to connect but not to process a trivial http request in time) as both states should result in
+different actions (restarting the service vs scaling it). It also exposes some data to be
+consumed as custom metrics, for example to be used in autoscaling decisions.
+"""
+
+import asyncio
+import functools
+import logging
+import os
+from collections import namedtuple
+from typing import Optional
+
+import aiohttp
+import psutil
+from starlette.applications import Starlette
+from starlette.requests import Request
+from starlette.responses import Response
+from starlette.routing import Route
+
+
+logger = logging.getLogger(__name__)
+
+
+METRICS = ""
+STATUS_OK = 0
+STATUS_BUSY = 1
+STATUS_ERROR = 2
+
+
+def metrics():
+    logging.debug("Requesting metrics")
+    return METRICS
+
+
+async def metrics_route(_request: Request) -> Response:
+    return Response(content=metrics())
+
+
+routes = [
+    Route("/{whatever:path}", metrics_route),
+]
+
+app = Starlette(routes=routes)
+
+
+def reset_logging():
+    if os.environ.get("METRICS_DEBUG", "false").lower() in ["1", "true"]:
+        level = logging.DEBUG
+    else:
+        level = logging.INFO
+    logging.basicConfig(
+        level=level,
+        format="healthchecks - %(asctime)s - %(levelname)s - %(message)s",
+        force=True,
+    )
+
+
+@app.on_event("startup")
+async def startup_event():
+    reset_logging()
+    # Link between `api-inference-community` and framework code.
+    asyncio.create_task(compute_metrics_loop(), name="compute_metrics")
+
+
+@functools.lru_cache()
+def get_listening_port():
+    logger.debug("Get listening port")
+    main_app_port = os.environ.get("MAIN_APP_PORT", "80")
+    try:
+        main_app_port = int(main_app_port)
+    except ValueError:
+        logger.warning(
+            "Main app port cannot be converted to an int, skipping and defaulting to 80"
+        )
+        main_app_port = 80
+    return main_app_port
+
+
+async def find_app_process(
+    listening_port: int,
+) -> Optional[namedtuple("addr", ["ip", "port"])]:  # noqa
+    connections = psutil.net_connections()
+    app_laddr = None
+    for c in connections:
+        if c.laddr.port != listening_port:
+            logger.debug("Skipping listening connection bound to excluded port %s", c)
+            continue
+        if c.status == psutil.CONN_LISTEN:
+            logger.debug("Found LISTEN conn %s", c)
+            candidate = c.pid
+            try:
+                p = psutil.Process(candidate)
+            except psutil.NoSuchProcess:
+                continue
+            if p.name() == "gunicorn":
+                logger.debug("Found gunicorn process %s", p)
+                app_laddr = c.laddr
+                break
+
+    return app_laddr
+
+
+def count_current_conns(app_port: int) -> str:
+    estab = []
+    conns = psutil.net_connections()
+
+    # logger.debug("Connections %s", conns)
+
+    for c in conns:
+        if c.status != psutil.CONN_ESTABLISHED:
+            continue
+        if c.laddr.port == app_port:
+            estab.append(c)
+    current_conns = len(estab)
+    logger.info("Current count of established connections to app: %d", current_conns)
+
+    curr_conns_str = """# HELP inference_app_established_conns Established connection count for a given app.
+# TYPE inference_app_established_conns gauge
+inference_app_established_conns{{port="{:d}"}} {:d}
+""".format(
+        app_port, current_conns
+    )
+    return curr_conns_str
+
+
+async def status_with_timeout(
+    listening_port: int, app_laddr: Optional[namedtuple("addr", ["ip", "port"])]  # noqa
+) -> str:
+    logger.debug("Checking application status")
+
+    status = STATUS_OK
+
+    if not app_laddr:
+        status = STATUS_ERROR
+    else:
+        try:
+            async with aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(total=0.5)
+            ) as session:
+                url = "http://{}:{:d}/".format(app_laddr.ip, app_laddr.port)
+                async with session.get(url) as resp:
+                    status_code = resp.status
+                    status_text = await resp.text()
+                    logger.debug("Status code %s and text %s", status_code, status_text)
+                    if status_code != 200 or status_text != '{"ok":"ok"}':
+                        status = STATUS_ERROR
+        except asyncio.TimeoutError:
+            logger.debug("Asgi app seems busy, unable to reach it before timeout")
+            status = STATUS_BUSY
+        except Exception as e:
+            logger.exception(e)
+            status = STATUS_ERROR
+
+    status_str = """# HELP inference_app_status Application health status (0: ok, 1: busy, 2: error).
+# TYPE inference_app_status gauge
+inference_app_status{{port="{:d}"}} {:d}
+""".format(
+        listening_port, status
+    )
+
+    return status_str
+
+
+async def single_metrics_compute():
+    global METRICS
+    listening_port = get_listening_port()
+    app_laddr = await find_app_process(listening_port)
+    current_conns = count_current_conns(listening_port)
+    status = await status_with_timeout(listening_port, app_laddr)
+
+    # Assignment is atomic, we should be safe without locking
+    METRICS = current_conns + status
+
+    # Persist metrics to the local ephemeral as well
+    metrics_file = os.environ.get("METRICS_FILE")
+    if metrics_file:
+        with open(metrics_file, "w") as f:
+            f.write(METRICS)
+
+
+@functools.lru_cache()
+def get_polling_sleep():
+    logger.debug("Get polling sleep interval")
+    sleep_value = os.environ.get("METRICS_POLLING_INTERVAL", 10)
+    try:
+        sleep_value = float(sleep_value)
+    except ValueError:
+        logger.warning(
+            "Unable to cast METRICS_POLLING_INTERVAL env value %s to float. Defaulting to 10.",
+            sleep_value,
+        )
+        sleep_value = 10.0
+    return sleep_value
+
+
+@functools.lru_cache()
+def get_initial_delay():
+    logger.debug("Get polling initial delay")
+    sleep_value = os.environ.get("METRICS_INITIAL_DELAY", 30)
+    try:
+        sleep_value = float(sleep_value)
+    except ValueError:
+        logger.warning(
+            "Unable to cast METRICS_INITIAL_DELAY env value %s to float. "
+            "Defaulting to 30.",
+            sleep_value,
+        )
+        sleep_value = 30.0
+    return sleep_value
+
+
+async def compute_metrics_loop():
+    initial_delay = get_initial_delay()
+
+    await asyncio.sleep(initial_delay)
+
+    polling_sleep = get_polling_sleep()
+    while True:
+        await asyncio.sleep(polling_sleep)
+        try:
+            await single_metrics_compute()
+        except Exception as e:
+            logger.error("Something wrong occurred while computing metrics")
+            logger.exception(e)
+
+
+if __name__ == "__main__":
+    reset_logging()
+    try:
+        asyncio.run(single_metrics_compute())
+        logger.info("Metrics %s", metrics())
+    except Exception as exc:
+        logging.exception(exc)
+        raise
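As a usage sketch, the side app above answers any path with the latest snapshot in Prometheus text format. The port 9000 below is a hypothetical choice for where this Starlette app is served; the `port` label in the output refers to the MAIN_APP_PORT of the watched app:

    import asyncio

    import aiohttp


    async def scrape(url: str = "http://localhost:9000/metrics") -> str:
        # Any path works: the app routes "/{whatever:path}" to metrics_route.
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as resp:
                return await resp.text()


    print(asyncio.run(scrape()))
    # Expected shape of the output (illustrative values):
    #   inference_app_established_conns{port="80"} 3
    #   inference_app_status{port="80"} 0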
