Skip to content

Commit

Permalink
FEAT: support distributed inference for sglang (xorbitsai#2877)
Browse files Browse the repository at this point in the history
  • Loading branch information
qinxuye authored Feb 19, 2025
1 parent ce8991a commit 8f86c1d
Show file tree
Hide file tree
Showing 5 changed files with 371 additions and 32 deletions.
19 changes: 19 additions & 0 deletions xinference/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ def __init__(
model_description: Optional["ModelDescription"] = None,
request_limits: Optional[int] = None,
xavier_config: Optional[Dict] = None,
n_worker: Optional[int] = 1,
shard: Optional[int] = 0,
driver_info: Optional[dict] = None, # for model across workers
):
super().__init__()
from ..model.llm.lmdeploy.core import LMDeployModel
Expand Down Expand Up @@ -263,6 +266,10 @@ def __init__(
"quantization": self._model_description.get("quantization", "none"),
}
self._loop: Optional[asyncio.AbstractEventLoop] = None
# model across workers
self._n_worker = n_worker
self._shard = shard
self._driver_info = driver_info

self._scheduler_ref = None
self._text_to_image_scheduler_ref = None
Expand Down Expand Up @@ -455,6 +462,8 @@ async def load(self):
i += 1
try:
self._model.load()
if hasattr(self._model, "driver_info"):
self._driver_info = self._model.driver_info
break
except Exception as e:
if (
Expand All @@ -477,6 +486,10 @@ async def load(self):
)
logger.info(f"{self} loaded")

async def wait_for_load(self):
if hasattr(self._model, "wait_for_load"):
self._model.wait_for_load()

def model_uid(self):
return (
self._model.model_uid
Expand All @@ -488,6 +501,12 @@ def model_uid(self):
)
)

def get_driver_info(self):
# driver info is used for model across workers,
# the driver model actor(always the first worker)
# will hold driver information includes dist store etc.
return self._driver_info

async def _handle_oom_error(self, ex):
error_message = (
f"Model actor is out of memory, model id: {self.model_uid()}, error: {ex}"
Expand Down
1 change: 1 addition & 0 deletions xinference/core/status_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class InstanceInfo(BaseModel):
replica: int
status: str
instance_created_ts: int
n_worker: Optional[int] = 1

def update(self, **kwargs):
for field, value in kwargs.items():
Expand Down
Loading

0 comments on commit 8f86c1d

Please sign in to comment.