Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,6 @@ models/*

# IDE related
.vscode

# Custom logs
logs/
11 changes: 8 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: dev run download install lint format check test smoke clean help
.PHONY: dev run download install lint format check test smoke swagger clean help

help: ## Show this help
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}'
Expand All @@ -18,10 +18,11 @@ run: ## Start server via start.sh
lint: ## Run ruff linter
uv run ruff check slm_server/

format: ## Run ruff formatter
format: ## Run ruff linter (--fix) and formatter
uv run ruff check slm_server/ --fix
uv run ruff format slm_server/

check: lint ## Run linter + formatter check
check: lint ## Run linter + formatter check (CI)
uv run ruff format --check slm_server/

smoke: ## Smoke-test the running server APIs with curl
Expand All @@ -30,6 +31,10 @@ smoke: ## Smoke-test the running server APIs with curl
test: ## Run tests with coverage
uv run pytest tests/ -v --cov=slm_server --cov-report=term-missing

swagger: ## Refresh OpenAPI spec from running server
curl -sf http://localhost:8000/openapi.json | uv run python -c "import sys,json,yaml;yaml.dump(json.load(sys.stdin),sys.stdout,default_flow_style=False,sort_keys=False,allow_unicode=True)" > swagger/openapi.yaml
@echo "swagger/openapi.yaml updated"

clean: ## Remove caches and build artifacts
rm -rf __pycache__ .pytest_cache .ruff_cache .coverage htmlcov build dist *.egg-info
find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"fastapi>=0.116.1",
"llama-cpp-python>=0.3.13",
# Cherry-picked PR #1884 (streaming tool use) onto latest upstream.
# Upstream llama-cpp-python silently ignores tool_choice when stream=True;
# this fork adds streaming support.
"llama-cpp-python @ git+https://github.com/XyLearningProgramming/llama-cpp-python.git@main",
"opentelemetry-instrumentation-logging>=0.50b0",
"opentelemetry-instrumentation-fastapi>=0.50b0",
"pydantic-settings>=2.10.1",
Expand Down
102 changes: 100 additions & 2 deletions scripts/smoke.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ curl -sf "$BASE_URL/api/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Say hello in one sentence."}],
"max_tokens": 64
"max_tokens": 512
}' | python3 -m json.tool
echo

Expand All @@ -26,11 +26,109 @@ curl -sf "$BASE_URL/api/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "What is 2+2?"}],
"max_tokens": 32,
"max_tokens": 512,
"stream": true
}'
echo

echo "=== Tool call (no tool_choice, defaults to auto) ==="
# Keep the JSON payload in a variable so the curl invocation stays readable.
TOOL_PAYLOAD='{
  "messages": [{"role": "user", "content": "What is the weather in San Francisco? /no_think"}],
  "max_tokens": 256,
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
          "type": "object",
          "properties": {
            "location": {
              "type": "string",
              "description": "City name"
            }
          },
          "required": ["location"]
        }
      }
    }
  ]
}'
TOOL_RESP=$(curl -sf "$BASE_URL/api/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d "$TOOL_PAYLOAD")
echo "$TOOL_RESP" | python3 -m json.tool

# Verify the response carries structured tool_calls rather than leaking
# raw <tool_call> tags into the assistant message content.
echo "$TOOL_RESP" | python3 -c "
import sys, json
resp = json.load(sys.stdin)
choice = resp['choices'][0]
msg = choice['message']
has_tool = 'tool_calls' in msg and msg['tool_calls']
has_content = 'content' in msg and msg['content']
if not has_tool and not has_content:
    print('FAIL: no tool_calls and no content'); sys.exit(1)
if has_tool:
    tc = msg['tool_calls'][0]
    assert tc['type'] == 'function', f'bad type: {tc[\"type\"]}'
    assert 'name' in tc['function'], 'missing function name'
    assert 'arguments' in tc['function'], 'missing arguments'
    assert '<tool_call>' not in (msg.get('content') or ''), 'raw <tool_call> leaked into content'
    assert choice['finish_reason'] == 'tool_calls', f'bad finish_reason: {choice[\"finish_reason\"]}'
    print(f'tool_calls: {tc[\"function\"][\"name\"]}({tc[\"function\"][\"arguments\"]})')
else:
    print(f'content_only: {msg[\"content\"][:80]}...')
"
echo

echo "=== Tool call streaming (no tool_choice, defaults to auto) ==="
# Same tool definition as the non-streaming probe, but with "stream": true.
STREAM_PAYLOAD='{
  "messages": [{"role": "user", "content": "What is the weather in San Francisco? /no_think"}],
  "max_tokens": 256,
  "stream": true,
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
          "type": "object",
          "properties": {
            "location": {
              "type": "string",
              "description": "City name"
            }
          },
          "required": ["location"]
        }
      }
    }
  ]
}'
STREAM_RESP=$(curl -sf "$BASE_URL/api/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d "$STREAM_PAYLOAD")
echo "$STREAM_RESP"

# SSE chunks must carry structured tool_call deltas; any raw <tool_call>
# markup in the stream means server-side post-processing failed.
echo "$STREAM_RESP" | python3 -c "
import sys
raw = sys.stdin.read()
assert '<tool_call>' not in raw, 'raw <tool_call> tag leaked into stream'
assert '</tool_call>' not in raw, 'raw </tool_call> tag leaked into stream'
# Check for structured tool_calls in at least one chunk
has_tool_calls = '\"tool_calls\"' in raw
has_content = '\"content\"' in raw
if has_tool_calls:
    print('streaming tool_calls: structured delta found')
elif has_content:
    print('streaming tool_calls: content_only (model chose not to call tool)')
else:
    print('FAIL: no tool_calls and no content in stream'); sys.exit(1)
"
echo

echo "=== Embeddings (single) ==="
curl -sf "$BASE_URL/api/v1/embeddings" \
-H "Content-Type: application/json" \
Expand Down
70 changes: 49 additions & 21 deletions slm_server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@
from fastapi.responses import StreamingResponse
from llama_cpp import CreateChatCompletionStreamResponse, Llama

from slm_server.config import Settings, get_settings
from slm_server.config import Settings, get_model_id, get_settings
from slm_server.embedding import OnnxEmbeddingModel
from slm_server.logging import setup_logging
from slm_server.metrics import setup_metrics
from slm_server.model import (
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingData,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
ModelListResponse,
register_streaming_schema,
)
from slm_server.trace import setup_tracing
from slm_server.utils import (
Expand All @@ -29,14 +31,16 @@
slm_embedding_span,
slm_span,
)
from slm_server.utils.postprocess import StreamPostProcessor, postprocess_completion

# MAX_CONCURRENCY decides how many threads may call the model at once.
# Defaults to 1 since llama.cpp is most efficient when run single-threaded;
# values larger than 1 would let threads compete for the same resources.
MAX_CONCURRENCY = 1
# Keeps function calling and also compatible with ReAct agents.
CHAT_FORMAT = "chatml-function-calling"
# Use the model's built-in Jinja chat template from the GGUF metadata,
# which handles tool formatting natively (e.g. Qwen3, Llama 3, etc.).
CHAT_FORMAT = None
# Default timeout message in detail field.
DETAIL_SEM_TIMEOUT = "Server is busy, please try again later."
# Status code for semaphore timeout.
Expand All @@ -46,6 +50,10 @@
STATUS_CODE_EXCEPTION = HTTPStatus.INTERNAL_SERVER_ERROR
# Media type for streaming responses.
STREAM_RESPONSE_MEDIA_TYPE = "text/event-stream"
# Schema for streaming responses.
STREAM_RESPONSE_SCHEMA = {
"schema": {"$ref": "#/components/schemas/ChatCompletionChunkResponse"}
}


def get_llm_semaphor() -> asyncio.Semaphore:
Expand Down Expand Up @@ -96,6 +104,8 @@ def get_app() -> FastAPI:
# Setup trace and OTel metrics (this will also instrument FastAPI)
setup_tracing(app, settings.tracing)

register_streaming_schema(app)

return app


Expand Down Expand Up @@ -130,7 +140,7 @@ def raise_as_http_exception() -> Generator[Literal[True], None, None]:


async def run_llm_streaming(
llm: Llama, req: ChatCompletionRequest
llm: Llama, req: ChatCompletionRequest, *, model_id: str
) -> AsyncGenerator[str, None]:
"""Generator that runs the LLM and yields SSE chunks under lock."""
with slm_span(req, is_streaming=True) as span:
Expand All @@ -139,58 +149,77 @@ async def run_llm_streaming(
**req.model_dump(),
)

# Use traced iterator that automatically handles chunk spans
# and parent span updates
processor = StreamPostProcessor(model_id=model_id)
chunk: CreateChatCompletionStreamResponse
for chunk in completion_stream:
set_atrribute_response_stream(span, chunk)
yield f"data: {json.dumps(chunk)}\n\n"
# NOTE: This is a workaround to yield control back to the event loop
# to allow checking for socket after yield and pop in CancelledError.
# Ref: https://github.com/encode/starlette/discussions/1776#discussioncomment-3207518
await asyncio.sleep(0)
for out_chunk in processor.process_chunk(chunk):
set_atrribute_response_stream(span, out_chunk)
yield f"data: {json.dumps(out_chunk)}\n\n"
# NOTE: yield control back to the event loop so starlette
# can detect client disconnects between chunks.
# Ref: https://github.com/encode/starlette/discussions/1776#discussioncomment-3207518
await asyncio.sleep(0)

for out_chunk in processor.flush():
set_atrribute_response_stream(span, out_chunk)
yield f"data: {json.dumps(out_chunk)}\n\n"

yield "data: [DONE]\n\n"


async def run_llm_non_streaming(llm: Llama, req: ChatCompletionRequest):
async def run_llm_non_streaming(
llm: Llama, req: ChatCompletionRequest, *, model_id: str
):
"""Runs the LLM for a non-streaming request under lock."""
with slm_span(req, is_streaming=False) as span:
completion_result = await asyncio.to_thread(
llm.create_chat_completion,
**req.model_dump(),
)
postprocess_completion(completion_result, model_id=model_id)
set_atrribute_response(span, completion_result)

return completion_result


@app.post("/api/v1/chat/completions")
@app.post(
"/api/v1/chat/completions",
response_model=ChatCompletionResponse,
responses={
200: {
"content": {
STREAM_RESPONSE_MEDIA_TYPE: STREAM_RESPONSE_SCHEMA,
},
},
},
)
async def create_chat_completion(
req: ChatCompletionRequest,
llm: Annotated[Llama, Depends(get_llm)],
model_id: Annotated[str, Depends(get_model_id)],
_: Annotated[None, Depends(lock_llm_semaphor)],
__: Annotated[None, Depends(raise_as_http_exception)],
):
) -> ChatCompletionResponse:
"""
Generates a chat completion, handling both streaming and non-streaming cases.
Concurrency is managed by the `locked_llm_session` context manager.
"""
if req.stream:
return StreamingResponse(
run_llm_streaming(llm, req), media_type=STREAM_RESPONSE_MEDIA_TYPE
run_llm_streaming(llm, req, model_id=model_id),
media_type=STREAM_RESPONSE_MEDIA_TYPE,
)
else:
return await run_llm_non_streaming(llm, req)
return await run_llm_non_streaming(llm, req, model_id=model_id)


@app.post("/api/v1/embeddings")
@app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(
req: EmbeddingRequest,
emb_model: Annotated[OnnxEmbeddingModel, Depends(get_embedding_model)],
_: Annotated[None, Depends(lock_llm_semaphor)],
__: Annotated[None, Depends(raise_as_http_exception)],
):
) -> EmbeddingResponse:
"""Create embeddings using the dedicated ONNX embedding model."""
with slm_embedding_span(req) as span:
inputs = req.input if isinstance(req.input, list) else [req.input]
Expand All @@ -211,7 +240,6 @@ async def list_models(
settings: Annotated[Settings, Depends(get_settings)],
) -> ModelListResponse:
"""List available models (OpenAI-compatible)."""
chat_model_id = Path(settings.model_path).stem
try:
chat_created = int(Path(settings.model_path).stat().st_mtime)
except (OSError, ValueError):
Expand All @@ -225,7 +253,7 @@ async def list_models(
return ModelListResponse(
data=[
ModelInfo(
id=chat_model_id,
id=settings.chat_model_id,
created=chat_created,
owned_by=settings.model_owner,
),
Expand Down
16 changes: 16 additions & 0 deletions slm_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Annotated
from fastapi import Depends

ENV_PREFIX = "SLM_"

Expand Down Expand Up @@ -85,6 +87,11 @@ class Settings(BaseSettings):
)

model_path: str = Field(MODEL_PATH_DEFAULT, description="Model path for llama_cpp.")
model_id: str = Field(
"",
description="Short model name in API responses (e.g. 'Qwen3-0.6B'). "
"Defaults to the GGUF filename stem when empty.",
)
model_owner: str = Field(
MODEL_OWNER_DEFAULT,
description="Owner label for /models list. Set SLM_MODEL_OWNER to override.",
Expand All @@ -103,6 +110,11 @@ class Settings(BaseSettings):
1, description="Seconds to wait if undergoing another inference."
)

@property
def chat_model_id(self) -> str:
    """Resolved model identifier: explicit ``model_id`` or GGUF stem.

    ``model_id`` defaults to ``""`` (falsy), so the ``or`` falls back to
    the filename stem of ``model_path``, e.g.
    ``models/Qwen3-0.6B.gguf`` -> ``Qwen3-0.6B``.
    """
    return self.model_id or Path(self.model_path).stem

embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings)
logging: LoggingSettings = Field(default_factory=LoggingSettings)
metrics: MetricsSettings = Field(default_factory=MetricsSettings)
def get_settings() -> Settings:
    """Return the process-wide ``Settings`` singleton.

    The instance is built lazily on first call and cached as an attribute
    on the function object itself, so subsequent calls are cheap.
    """
    instance = getattr(get_settings, "_instance", None)
    if instance is None:
        instance = Settings()
        get_settings._instance = instance
    return instance


def get_model_id(settings: Annotated[Settings, Depends(get_settings)]) -> str:
    """FastAPI dependency: resolve the chat model identifier from settings."""
    return settings.chat_model_id
Loading