Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ ELASTIC_PASSWORD=
PAGE_SIZE=1000
GCS_BUCKET=
GCS_PREFIX=

# Rate Limiting
RATE_LIMIT=10/minute
68 changes: 60 additions & 8 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,61 @@
import asyncio
from typing import Optional, Dict, Any
from datetime import datetime
import logging

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
import uvicorn
import json

from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

from agents import NeuroscienceAssistant

# Load variables from .env into the process environment before any os.getenv reads.
load_dotenv()

# Configure logging for this module (INFO and above to the root handler).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Rate limit configuration — slowapi/limits notation, e.g. "10/minute";
# overridable via the RATE_LIMIT environment variable (see .env.template).
RATE_LIMIT = os.getenv("RATE_LIMIT", "10/minute")

# Initialize rate limiter keyed on the client's remote address (per-IP limiting).
limiter = Limiter(key_func=get_remote_address)

# FastAPI app + CORS
app = FastAPI(
title="KnowledgeSpace AI",
description="Neuroscience Dataset Discovery Assistant",
version="2.0.0",
)

# Attach limiter to app state so slowapi's @limiter.limit decorators can locate it.
app.state.limiter = limiter


# Custom rate limit exceeded handler
@app.exception_handler(RateLimitExceeded)
async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
    """Return an HTTP 429 JSON response when a slowapi rate limit is hit.

    Logs the offending client IP and request path, then responds with a
    short human-readable message plus the triggering limit description.
    """
    # Lazy %-style args: the message is only formatted if WARNING is enabled.
    logger.warning(
        "Rate limit exceeded for IP: %s on path: %s",
        get_remote_address(request),
        request.url.path,
    )
    return JSONResponse(
        status_code=429,
        content={
            "detail": "Too many requests. Please wait and try again.",
            # NOTE(review): exc.detail is the limit description string
            # (e.g. "10 per 1 minute"), NOT a seconds value; key kept as
            # "retry_after" for backward compatibility with existing clients.
            # Consider emitting a real Retry-After header instead — confirm.
            "retry_after": str(exc.detail),
        },
    )


app.add_middleware(
CORSMiddleware,
allow_origins=[o.strip() for o in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")],
Expand All @@ -34,6 +70,7 @@
# Initialize the assistant with vector search agent on startup
assistant = NeuroscienceAssistant()


# Models
class ChatMessage(BaseModel):
query: str = Field(..., description="The user's query")
Expand All @@ -54,9 +91,8 @@ class ChatResponse(BaseModel):
# Lightweight health helpers

def _vector_check_sync() -> bool:

try:
from retrieval import Retriever # local import to avoid import penalty on startup
from retrieval import Retriever
r = Retriever()
return bool(getattr(r, "is_enabled", False))
except Exception:
Expand Down Expand Up @@ -110,15 +146,31 @@ async def health():


@app.post("/api/chat", response_model=ChatResponse, tags=["Chat"])
async def chat_endpoint(msg: ChatMessage):
@limiter.limit(RATE_LIMIT)
async def chat_endpoint(request: Request, msg: ChatMessage):
try:
start_time = time.time()

# Log the request
client_ip = get_remote_address(request)
logger.info(
f"Chat request from {client_ip} | "
f"session: {msg.session_id} | "
f"query length: {len(msg.query)}"
)

response_text = await assistant.handle_chat(
session_id=msg.session_id or "default",
query=msg.query,
reset=bool(msg.reset),
)
process_time = time.time() - start_time

logger.info(
f"Chat response sent to {client_ip} | "
f"process_time: {process_time:.2f}s"
)

metadata = {
"process_time": process_time,
"session_id": msg.session_id,
Expand All @@ -132,14 +184,13 @@ async def chat_endpoint(msg: ChatMessage):
detail="Request timed out. Please try with a simpler query.",
)
except Exception as e:
logger.error(f"Chat error: {e}")
return ChatResponse(
response=f"Error: {e}",
metadata={"error": True, "session_id": msg.session_id},
)




@app.post("/api/session/reset", tags=["Chat"])
async def reset_session(payload: Dict[str, str]):
sid = (payload or {}).get("session_id") or "default"
Expand All @@ -149,12 +200,13 @@ async def reset_session(payload: Dict[str, str]):

# Entry point
if __name__ == "__main__":
    logger.info("Starting server with rate limit: %s", RATE_LIMIT)
    env = os.getenv("ENVIRONMENT", "production").lower()
    uvicorn.run(
        "main:app",
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
        # Bug fix: `env` was computed but never used and reload was hard-coded
        # to True, enabling the auto-reloader even in production. Reload only
        # when not running in production.
        reload=env != "production",
        log_level="info",
        proxy_headers=True,
    )
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"requests>=2.32.4",
"scikit-learn>=1.7.0",
"sentence-transformers>=3.0.0",
"slowapi>=0.1.9",
"sqlalchemy>=2.0.42",
"torch>=2.7.1",
"tqdm>=4.67.1",
Expand All @@ -37,4 +38,4 @@ dev = [
"isort>=5.12.0",
"flake8>=6.0.0",
"mypy>=1.0.0",
]
]