Use aiohttp inside proxy server && add --disable-cache-status argument #3020

Open · wants to merge 4 commits into base: main
5 changes: 5 additions & 0 deletions lmdeploy/cli/serve.py
@@ -209,6 +209,11 @@ def add_parser_proxy():
                             choices=['random', 'min_expected_latency', 'min_observed_latency'],
                             default='min_expected_latency',
                             help='the strategy to dispatch requests to nodes')
+        parser.add_argument('--disable-cache-status',
+                            action='store_true',
+                            help='Whether to disable caching the proxy status to '
+                            'the config file. If set, the proxy will not '
+                            'restore the node status from the previous run')
         ArgumentHelper.api_keys(parser)
         ArgumentHelper.ssl(parser)
         ArgumentHelper.log_level(parser)
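
For context, the new flag reaches the proxy through its Python entry point. A minimal sketch of launching the proxy with status caching disabled — the `lmdeploy serve proxy` CLI spelling and the CLI-to-function wiring are assumptions; the function and flag names come from this diff:

from lmdeploy.serve.proxy.proxy import proxy

# Rough equivalent of `lmdeploy serve proxy --disable-cache-status`
# (CLI wiring assumed):
proxy(server_name='0.0.0.0',
      strategy='min_expected_latency',
      disable_cache_status=True)  # proxy_config.yml is neither read nor written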
48 changes: 27 additions & 21 deletions lmdeploy/serve/proxy/proxy.py
@@ -12,6 +12,7 @@
 from http import HTTPStatus
 from typing import Deque, Dict, List, Literal, Optional, Union
 
+import aiohttp
 import numpy as np
 import requests
 import uvicorn
@@ -20,7 +21,6 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, Field
-from requests.exceptions import RequestException
 
 from lmdeploy.serve.openai.api_server import check_api_key, create_error_response
 from lmdeploy.serve.openai.protocol import ModelCard  # noqa: E501
@@ -70,14 +70,18 @@ class NodeManager:
     to.
     """
 
-    def __init__(self, config_path: Optional[str] = None, strategy: str = 'min_expected_latency') -> None:
+    def __init__(self,
+                 config_path: Optional[str] = None,
+                 strategy: str = 'min_expected_latency',
+                 cache_status: Optional[bool] = True) -> None:
         self.nodes = dict()
         self.strategy = Strategy.from_str(strategy)
+        self.cache_status = cache_status
         self.latencies = dict()
         self.config_path = osp.join(osp.dirname(osp.realpath(__file__)), 'proxy_config.yml')
         if config_path is not None:
             self.config_path = config_path
-        if osp.exists(self.config_path):
+        if osp.exists(self.config_path) and self.cache_status:
             with open(self.config_path, 'r') as config_file:
                 self.nodes = yaml.safe_load(config_file)['nodes']
             for url, status in self.nodes.items():
@@ -87,15 +91,17 @@ def __init__(self, config_path: Optional[str] = None, strategy: str = 'min_expected_latency') -> None:
                 self.nodes[url] = status
         self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self, ), daemon=True)
         self.heart_beat_thread.start()
+        self.aiotimeout = aiohttp.ClientTimeout(total=API_READ_TIMEOUT)
 
     def update_config_file(self):
         """Update the config file."""
         nodes = copy.deepcopy(self.nodes)
         for url, status in nodes.items():
             nodes[url] = status.model_dump()
             nodes[url]['latency'] = list(status.latency)[-LATENCY_DEQUE_LEN:]
-        with open(self.config_path, 'w') as config_file:  # update cfg yml
-            yaml.dump(dict(nodes=nodes), config_file)
+        if self.cache_status:
+            with open(self.config_path, 'w') as config_file:  # update cfg yml
+                yaml.dump(dict(nodes=nodes), config_file)
 
     def add(self, node_url: str, status: Optional[Status] = None):
         """Add a node to the manager.
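
The gating added above is the entire caching switch on the write side. A standalone sketch of the same behavior, simplified from `update_config_file` (the free-function form and argument names are illustrative):

import yaml

def dump_nodes(nodes: dict, config_path: str, cache_status: bool) -> None:
    # When --disable-cache-status is passed, cache_status is False and
    # proxy_config.yml is left untouched, so a restarted proxy starts
    # with an empty node list instead of the previous run's status.
    if not cache_status:
        return
    with open(config_path, 'w') as config_file:
        yaml.dump(dict(nodes=nodes), config_file)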
@@ -257,7 +263,7 @@ def handle_api_timeout(self, node_url):
         }
         return json.dumps(ret).encode() + b'\n'
 
-    def stream_generate(self, request: Dict, node_url: str, endpoint: str):
+    async def stream_generate(self, request: Dict, node_url: str, endpoint: str):
         """Return a generator to handle the input request.
 
         Args:
@@ -266,16 +272,12 @@ def stream_generate(self, request: Dict, node_url: str, endpoint: str):
             endpoint (str): the endpoint. Such as `/v1/chat/completions`.
         """
         try:
-            response = requests.post(
-                node_url + endpoint,
-                json=request,
-                stream=True,
-                timeout=(5, API_READ_TIMEOUT),
-            )
-            for chunk in response.iter_lines(decode_unicode=False, delimiter=b'\n'):
-                if chunk:
-                    yield chunk + b'\n\n'
-        except (Exception, GeneratorExit, RequestException) as e:  # noqa
+            async with aiohttp.ClientSession() as session:
+                async with session.post(node_url + endpoint, json=request, timeout=self.aiotimeout) as response:
+                    async for line in response.content:
+                        if line.strip():
+                            yield line + b'\n\n'
+        except (Exception, GeneratorExit, aiohttp.ClientError) as e:  # noqa
             logger.error(f'caught an exception: {e}')
             # exception happened, reduce unfinished num
             yield self.handle_api_timeout(node_url)
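
For reference, the streaming pattern this hunk switches to can be exercised standalone. A minimal sketch — the URL, payload, and the 600-second timeout are placeholders; the proxy itself reuses `self.aiotimeout`:

import asyncio

import aiohttp

async def stream_lines(url: str, payload: dict) -> None:
    timeout = aiohttp.ClientTimeout(total=600)  # stand-in for API_READ_TIMEOUT
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload, timeout=timeout) as response:
            # response.content is an aiohttp StreamReader; async iteration
            # yields one line at a time without blocking the event loop
            async for line in response.content:
                if line.strip():
                    print(line.decode(errors='replace'))

# asyncio.run(stream_lines('http://0.0.0.0:23333/v1/chat/completions',
#                          {'model': 'llm', 'stream': True}))

Unlike `requests.iter_lines`, which ties up a thread for the lifetime of the stream, async iteration lets a single event loop multiplex many concurrent node streams.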
@@ -289,11 +291,10 @@ async def generate(self, request: Dict, node_url: str, endpoint: str):
             endpoint (str): the endpoint. Such as `/v1/chat/completions`.
         """
         try:
-            import httpx
-            async with httpx.AsyncClient() as client:
-                response = await client.post(node_url + endpoint, json=request, timeout=API_READ_TIMEOUT)
-                return response.text
-        except (Exception, GeneratorExit, RequestException, asyncio.CancelledError) as e:  # noqa
+            async with aiohttp.ClientSession() as session:
+                async with session.post(node_url + endpoint, json=request, timeout=self.aiotimeout) as response:
+                    return await response.text()
+        except (Exception, GeneratorExit, aiohttp.ClientError, asyncio.CancelledError) as e:  # noqa # yapf: disable
             logger.error(f'caught an exception: {e}')
             return self.handle_api_timeout(node_url)
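
The non-streaming path follows the same shape; a minimal standalone equivalent (same placeholder caveats as above):

import aiohttp

async def generate_once(url: str, payload: dict) -> str:
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload,
                                timeout=aiohttp.ClientTimeout(total=600)) as response:
            # read the whole body at once rather than line by line
            return await response.text()

Switching this path from httpx to aiohttp also drops the function-local `import httpx`, so both request paths now share a single HTTP client library.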

@@ -529,6 +530,7 @@ def proxy(server_name: str = '0.0.0.0',
           api_keys: Optional[Union[List[str], str]] = None,
           ssl: bool = False,
           log_level: str = 'INFO',
+          disable_cache_status: bool = False,
           **kwargs):
     """To launch the proxy server.
 
@@ -541,8 +543,12 @@ def proxy(server_name: str = '0.0.0.0',
         api_keys (List[str] | str | None): Optional list of API keys. Accepts string type as
             a single api_key. Default to None, which means no api key applied.
         ssl (bool): Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.
+        log_level (str): Set the log level. Default to INFO.
+        disable_cache_status (bool): Whether to disable caching the proxy
+            status to proxy_config.yml. Default to False.
     """  # noqa
     node_manager.strategy = Strategy.from_str(strategy)
+    node_manager.cache_status = not disable_cache_status
     if api_keys is not None:
         if isinstance(api_keys, str):
             api_keys = api_keys.split(',')