Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ The backend requires specific environment variables to connect to **Google Cloud
| `BQ_DATASET_ID` | BigQuery dataset ID | Dataset containing KnowledgeSpace metadata |
| `INDEX_ENDPOINT_ID` | Vertex AI Vector Search endpoint | ID of deployed vector index for RAG |
| `ELASTIC_BASE_URL` | Elasticsearch base URL | URL of the text search engine |
| `SESSION_MAX_COUNT` | Maximum in-memory chat sessions | Defaults to `1000` |
| `SESSION_TTL_SECONDS` | Inactive session eviction time | Defaults to `3600`; set `0` to disable TTL |
| `SESSION_HISTORY_LIMIT` | Messages retained per session | Defaults to `20` |

The backend keeps chat history and paginated search results in memory. Sessions
are evicted least-recently-used (LRU) first and after a time-to-live (TTL) of
inactivity, which keeps that memory bounded in long-running deployments.

---

Expand Down
264 changes: 192 additions & 72 deletions backend/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import re
import json
import asyncio
from collections import OrderedDict
from enum import Enum
from time import monotonic
from typing import Dict, List, Optional, TypedDict, Any
import logging

Expand Down Expand Up @@ -88,6 +90,21 @@ def _get_genai_client():
FLASH_LITE_MODEL = os.getenv("GEMINI_FLASH_LITE_MODEL", "gemini-2.5-flash-lite")


def _env_int_at_least(name: str, default: int, minimum: int) -> int:
    """Read an integer setting from the environment, enforcing a lower bound.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset, is not a valid
            integer, or is below ``minimum``.
        minimum: Smallest accepted value (inclusive).

    Returns:
        The parsed integer when it is valid and ``>= minimum``; otherwise
        ``default``. A warning is logged whenever a set value is rejected.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        logger.warning("Invalid %s=%r; using default %s", name, raw, default)
        return default
    if value < minimum:
        # Report the rejected value as well, matching the invalid-integer
        # warning, so the misconfiguration is visible from the log line alone.
        logger.warning(
            "%s=%r must be >= %s; using default %s", name, raw, minimum, default
        )
        return default
    return value


# Query intent/types
class QueryIntent(Enum):
DATA_DISCOVERY = "data_discovery"
Expand Down Expand Up @@ -503,8 +520,14 @@ async def generate_final_response(state: AgentState) -> AgentState:

class NeuroscienceAssistant:
def __init__(self):
    """Initialise bounded in-memory session state and compile the agent graph.

    The three limits are read from the environment once, here; invalid or
    out-of-range values fall back to the defaults passed below.
    """
    self.max_sessions = _env_int_at_least("SESSION_MAX_COUNT", 1000, 1)
    # A TTL of 0 is accepted; _evict_expired_sessions returns early for it.
    self.session_ttl_seconds = _env_int_at_least("SESSION_TTL_SECONDS", 3600, 0)
    self.history_limit = _env_int_at_least("SESSION_HISTORY_LIMIT", 20, 1)
    # Per-session stores are ordered so _touch_session can move an id to the
    # end and eviction can walk them from the front.
    self.chat_history: OrderedDict[str, List[str]] = OrderedDict()
    self.session_memory: OrderedDict[str, Dict[str, Any]] = OrderedDict()
    # session id -> monotonic() timestamp of the most recent touch.
    self._session_last_seen: OrderedDict[str, float] = OrderedDict()
    # session id -> identity token, compared with `is` in _session_is_current.
    self._session_tokens: Dict[str, object] = {}
    # session id -> number of reservations not yet released.
    self._session_active_counts: Dict[str, int] = {}
    self.graph = self._build_graph()

def _build_graph(self):
Expand All @@ -520,85 +543,182 @@ def _build_graph(self):
workflow.add_edge("generate_response", END)
return workflow.compile()

def reset_session(self, session_id: str):
def _drop_session(self, session_id: str):
    """Forget everything stored for ``session_id``; unknown ids are ignored.

    The reservation counter in ``_session_active_counts`` is left untouched.
    """
    per_session_stores = (
        self.chat_history,
        self.session_memory,
        self._session_last_seen,
        self._session_tokens,
    )
    for store in per_session_stores:
        store.pop(session_id, None)

def _session_is_active(self, session_id: str) -> bool:
    """Return True while ``session_id`` has at least one unreleased reservation."""
    reservations = self._session_active_counts.get(session_id, 0)
    return reservations > 0

def _reserve_session(self, session_id: str):
    """Record one more reservation for ``session_id`` (starting from zero)."""
    counts = self._session_active_counts
    counts[session_id] = counts.get(session_id, 0) + 1

def _release_session(self, session_id: str):
count = self._session_active_counts.get(session_id, 0)
if count <= 1:
self._session_active_counts.pop(session_id, None)
else:
self._session_active_counts[session_id] = count - 1
self._evict_expired_sessions(monotonic())
self._evict_overflow_sessions()

def _evict_expired_sessions(self, now: float):
if self.session_ttl_seconds <= 0:
return
cutoff = now - self.session_ttl_seconds
for session_id, last_seen in list(self._session_last_seen.items()):
if last_seen >= cutoff:
break
if self._session_is_active(session_id):
continue
self._drop_session(session_id)

def _evict_overflow_sessions(self, protected_session_id: Optional[str] = None):
while len(self._session_last_seen) > self.max_sessions:
evicted = False
for session_id in list(self._session_last_seen):
if session_id == protected_session_id:
continue
if self._session_is_active(session_id):
continue
self._drop_session(session_id)
evicted = True
break
if not evicted:
break

def _touch_session(self, session_id: str):
now = monotonic()
self._evict_expired_sessions(now)
if session_id in self._session_last_seen:
self._session_last_seen.move_to_end(session_id)
self._session_last_seen[session_id] = now
if session_id in self.chat_history:
self.chat_history.move_to_end(session_id)
if session_id in self.session_memory:
self.session_memory.move_to_end(session_id)
self._evict_overflow_sessions(protected_session_id=session_id)

def _trim_history(self, session_id: str):
    """Keep only the newest ``history_limit`` entries of the session's history.

    A history that is already within the limit is left as the same list
    object; otherwise it is replaced by a shortened copy.
    """
    limit = self.history_limit
    messages = self.chat_history[session_id]
    overflow = len(messages) - limit
    if overflow > 0:
        self.chat_history[session_id] = messages[-limit:]

def _ensure_session(self, session_id: str):
self._touch_session(session_id)
if session_id not in self.chat_history:
self.chat_history[session_id] = []
if session_id not in self._session_tokens:
self._session_tokens[session_id] = object()

def _session_is_current(self, session_id: str, token: object) -> bool:
    """Return True if ``token`` is still the stored token of a live session.

    The session counts as live only while it has both a last-seen entry and
    a history entry; the token is compared by identity.
    """
    if self._session_tokens.get(session_id) is not token:
        return False
    if session_id not in self._session_last_seen:
        return False
    return session_id in self.chat_history

def reset_session(self, session_id: str):
self._drop_session(session_id)


async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> str:
try:
if reset:
self.reset_session(session_id)
if session_id not in self.chat_history:
self.chat_history[session_id] = []

more_count = _is_more_query(query)
mem = self.session_memory.get(session_id, {})
if more_count is not None or (query.strip().lower() in {"more", "next", "continue", "more please", "show more", "keep going"}):
all_results = mem.get("all_results", [])
if not all_results:
return "There are no earlier results to continue. Ask me for a dataset (e.g., 'human EEG BIDS')."
page_size = more_count or mem.get("page_size", 15)
page = mem.get("page", 1) + 1
start = (page - 1) * page_size
batch = all_results[start:start + page_size]
if not batch:
return "You've reached the end of the results. Try refining the query."
intents = mem.get("intents", [QueryIntent.DATA_DISCOVERY.value])
effective_query = mem.get("effective_query", "")
prev_text = mem.get("last_text", "")

try:
text = await call_gemini_for_final_synthesis(
effective_query, batch, intents, start_number=start + 1, previous_text=prev_text
)
except Exception:
text = "Unable to process your request. Please try again."
mem.update({
"page": page,
"page_size": page_size,
"last_text": f"{prev_text}\n\n{text}"[-12000:],
})
self.session_memory[session_id] = mem
self.chat_history[session_id].extend([f"User: {query}", f"Assistant: {text}"])
if len(self.chat_history[session_id]) > 20:
self.chat_history[session_id] = self.chat_history[session_id][-20:]
return text

initial_state: AgentState = {
"session_id": session_id,
"query": query,
"history": self.chat_history[session_id][-10:],
"keywords": [],
"effective_query": "",
"intents": [],
"ks_results": [],
"vector_results": [],
"final_results": [],
"all_results": [],
"start_number": 1,
"previous_text": "",
"final_response": "",
}
final_state = await self.graph.ainvoke(initial_state)
response_text = final_state.get("final_response", "I encountered an unexpected empty response.")

self.session_memory[session_id] = {
"all_results": final_state.get("all_results", []),
"page": 1,
"page_size": 15,
"effective_query": final_state.get("effective_query", initial_state["query"]),
"keywords": final_state.get("keywords", []),
"intents": final_state.get("intents", [QueryIntent.DATA_DISCOVERY.value]),
"last_text": response_text,
}

self.chat_history[session_id].extend([f"User: {query}", f"Assistant: {response_text}"])
if len(self.chat_history[session_id]) > 20:
self.chat_history[session_id] = self.chat_history[session_id][-20:]
return response_text
self._ensure_session(session_id)
session_token = self._session_tokens[session_id]
self._reserve_session(session_id)
try:
more_count = _is_more_query(query)
mem = self.session_memory.get(session_id, {})
if more_count is not None or (
query.strip().lower() in {"more", "next", "continue", "more please", "show more", "keep going"}
):
all_results = mem.get("all_results", [])
if not all_results:
return (
"There are no earlier results to continue. "
"Ask me for a dataset (e.g., 'human EEG BIDS')."
)
page_size = more_count or mem.get("page_size", 15)
page = mem.get("page", 1) + 1
start = (page - 1) * page_size
batch = all_results[start:start + page_size]
if not batch:
return "You've reached the end of the results. Try refining the query."
intents = mem.get("intents", [QueryIntent.DATA_DISCOVERY.value])
effective_query = mem.get("effective_query", "")
prev_text = mem.get("last_text", "")

try:
text = await call_gemini_for_final_synthesis(
effective_query,
batch,
intents,
start_number=start + 1,
previous_text=prev_text,
)
except Exception:
text = "Unable to process your request. Please try again."
if not self._session_is_current(session_id, session_token):
return text
self._touch_session(session_id)
mem.update({
"page": page,
"page_size": page_size,
"last_text": f"{prev_text}\n\n{text}"[-12000:],
})
self.session_memory[session_id] = mem
self.session_memory.move_to_end(session_id)
self.chat_history[session_id].extend([f"User: {query}", f"Assistant: {text}"])
self._trim_history(session_id)
return text

initial_state: AgentState = {
"session_id": session_id,
"query": query,
"history": self.chat_history[session_id][-10:],
"keywords": [],
"effective_query": "",
"intents": [],
"ks_results": [],
"vector_results": [],
"final_results": [],
"all_results": [],
"start_number": 1,
"previous_text": "",
"final_response": "",
}
final_state = await self.graph.ainvoke(initial_state)
response_text = final_state.get(
"final_response",
"I encountered an unexpected empty response.",
)

if not self._session_is_current(session_id, session_token):
return response_text
self._touch_session(session_id)
self.session_memory[session_id] = {
"all_results": final_state.get("all_results", []),
"page": 1,
"page_size": 15,
"effective_query": final_state.get("effective_query", initial_state["query"]),
"keywords": final_state.get("keywords", []),
"intents": final_state.get("intents", [QueryIntent.DATA_DISCOVERY.value]),
"last_text": response_text,
}
self.session_memory.move_to_end(session_id)

self.chat_history[session_id].extend([f"User: {query}", f"Assistant: {response_text}"])
self._trim_history(session_id)
return response_text
finally:
self._release_session(session_id)
except Exception as e:
logger.error("Error in handle_chat: %s", e)
import traceback
logger.exception("Exception occurred in handle_chat")
return "I encountered an error. Please try again."
return "I encountered an error. Please try again."
Loading