Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions openviking/server/routers/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_bot_url() -> str:
return BOT_API_URL


async def verify_auth(request: Request) -> Optional[str]:
def extract_auth_token(request: Request) -> Optional[str]:
"""Extract and return authorization token from request."""
# Try X-API-Key header first
api_key = request.headers.get("X-API-Key")
Expand All @@ -52,6 +52,17 @@ async def verify_auth(request: Request) -> Optional[str]:
return None


def require_auth_token(request: Request) -> str:
"""Return an auth token or raise 401 for bot proxy endpoints."""
auth_token = extract_auth_token(request)
if not auth_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication token",
)
return auth_token


@router.get("/health")
async def health_check(request: Request):
"""Health check endpoint for Bot API.
Expand Down Expand Up @@ -92,7 +103,7 @@ async def chat(request: Request):
Proxies the request to Vikingbot OpenAPIChannel.
"""
bot_url = get_bot_url()
auth_token = await verify_auth(request)
auth_token = require_auth_token(request)

# Read request body
try:
Expand All @@ -106,9 +117,7 @@ async def chat(request: Request):
try:
async with httpx.AsyncClient() as client:
# Build headers
headers = {"Content-Type": "application/json"}
if auth_token:
headers["X-API-Key"] = auth_token
headers = {"Content-Type": "application/json", "X-API-Key": auth_token}

# Forward to Vikingbot OpenAPIChannel chat endpoint
response = await client.post(
Expand Down Expand Up @@ -146,7 +155,7 @@ async def chat_stream(request: Request):
Proxies the request to Vikingbot OpenAPIChannel with SSE streaming.
"""
bot_url = get_bot_url()
auth_token = await verify_auth(request)
auth_token = require_auth_token(request)

# Read request body
try:
Expand All @@ -162,9 +171,7 @@ async def event_stream() -> AsyncGenerator[str, None]:
try:
async with httpx.AsyncClient() as client:
# Build headers
headers = {"Content-Type": "application/json"}
if auth_token:
headers["X-API-Key"] = auth_token
headers = {"Content-Type": "application/json", "X-API-Key": auth_token}

# Forward to Vikingbot OpenAPIChannel stream endpoint
async with client.stream(
Expand Down
76 changes: 76 additions & 0 deletions tests/server/test_bot_proxy_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: Apache-2.0

"""Regression tests for bot proxy endpoint auth enforcement."""

import httpx
import pytest
import pytest_asyncio
from fastapi import FastAPI, HTTPException, Request

import openviking.server.routers.bot as bot_router_module


def make_request(headers: dict[str, str]) -> Request:
"""Create a minimal request object with the provided headers."""
return Request(
{
"type": "http",
"method": "POST",
"path": "/",
"headers": [
(key.lower().encode("latin-1"), value.encode("latin-1"))
for key, value in headers.items()
],
"query_string": b"",
}
)


@pytest_asyncio.fixture
async def bot_auth_client() -> httpx.AsyncClient:
"""Client mounted with bot router and bot backend configured."""
app = FastAPI()
old_bot_api_url = bot_router_module.BOT_API_URL
bot_router_module.set_bot_api_url("http://bot-backend.local")
app.include_router(bot_router_module.router, prefix="/bot/v1")
transport = httpx.ASGITransport(app=app)
try:
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
yield client
finally:
bot_router_module.BOT_API_URL = old_bot_api_url


@pytest.mark.parametrize(
("headers", "expected"),
[
({"X-API-Key": "test-key"}, "test-key"),
({"Authorization": "Bearer test-token"}, "test-token"),
],
)
def test_extract_auth_token(headers: dict[str, str], expected: str):
"""Accepted auth header formats should both produce a token."""
assert bot_router_module.extract_auth_token(make_request(headers)) == expected


def test_require_auth_token_rejects_missing_token():
"""Missing credentials should raise a 401 before proxying."""
with pytest.raises(HTTPException) as exc_info:
bot_router_module.require_auth_token(make_request({}))

assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Missing authentication token"


@pytest.mark.asyncio
@pytest.mark.parametrize("path", ["/bot/v1/chat", "/bot/v1/chat/stream"])
async def test_bot_proxy_requires_auth_token(bot_auth_client: httpx.AsyncClient, path: str):
"""Bot proxy endpoints should reject missing auth with 401."""
response = await bot_auth_client.post(
path,
json={"message": "hello"},
)

assert response.status_code == 401
assert response.json()["detail"] == "Missing authentication token"
Loading