Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions backend/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,39 @@ def _require_llm_creds() -> None:


# Process-wide cached google.genai client; built lazily on first use.
_GENAI_CLIENT = None
# Serializes first-time client construction across concurrent asyncio tasks.
_GENAI_CLIENT_LOCK = asyncio.Lock()


async def _get_genai_client():
    """
    Lazily build and cache a google.genai client.

    Two credential modes:
      * Vertex AI (ADC or a credentials file) when ``_use_vertex()`` is
        true, using ``GCP_PROJECT_ID`` / ``GCP_REGION`` (default
        ``"europe-west4"``).
      * Plain API-key mode via ``GOOGLE_API_KEY`` otherwise.

    Returns:
        The shared ``genai.Client`` instance (same object on every call).
    """
    global _GENAI_CLIENT
    # Fast path: client already built. Safe to read without the lock —
    # asyncio is single-threaded and the assignment below happens-before
    # any later read within a task.
    if _GENAI_CLIENT is not None:
        return _GENAI_CLIENT

    async with _GENAI_CLIENT_LOCK:
        # Double-check inside the lock: another task may have finished
        # construction while this one awaited the lock.
        if _GENAI_CLIENT is not None:
            return _GENAI_CLIENT

        if _use_vertex():
            project = os.getenv("GCP_PROJECT_ID")
            location = os.getenv("GCP_REGION") or "europe-west4"
            # Ensures ADC / credentials file is in place before the
            # Vertex-backed client is constructed.
            _ensure_google_creds_for_vertex()
            _GENAI_CLIENT = genai.Client(
                vertexai=True,
                project=project,
                location=location,
            )
        else:
            _GENAI_CLIENT = genai.Client(
                api_key=os.getenv("GOOGLE_API_KEY"),
            )

    return _GENAI_CLIENT


# Gemini model ids; overridable per-deployment via environment variables.
FLASH_MODEL = os.getenv("GEMINI_FLASH_MODEL", "gemini-2.5-flash")
FLASH_LITE_MODEL = os.getenv("GEMINI_FLASH_LITE_MODEL", "gemini-2.5-flash-lite")

Expand Down Expand Up @@ -124,7 +136,7 @@ async def call_gemini_for_keywords(query: str) -> List[str]:
No local greeting filters — prompt handles exclusions. Minimal trim+dedupe here.
"""
_require_llm_creds()
client = _get_genai_client()
client = await _get_genai_client()
prompt = (
"Extract important search keywords and multi-word phrases from a neuroscience *data* query.\n"
"Return STRICT JSON only:\n"
Expand Down Expand Up @@ -166,7 +178,7 @@ async def call_gemini_rewrite_with_history(query: str, history: List[str]) -> st
Keeps exact tokens and multi-word phrases intact.
"""
_require_llm_creds()
client = _get_genai_client()
client = await _get_genai_client()
last_user_turns = [h for h in history if h.startswith("User: ")]
ctx = "\n".join(last_user_turns[-6:])
prompt = (
Expand Down Expand Up @@ -201,7 +213,7 @@ async def call_gemini_detect_intents(query: str, history: List[str]) -> List[str
- If any data-related tokens exist, prefer data_discovery.
"""
_require_llm_creds()
client = _get_genai_client()
client = await _get_genai_client()
allowed = [i.value for i in QueryIntent]
last_user_turns = [h for h in history if h.startswith("User: ")]
ctx = "\n".join(last_user_turns[-6:])
Expand Down Expand Up @@ -242,7 +254,7 @@ async def call_gemini_for_final_synthesis(
) -> str:

_require_llm_creds()
client = _get_genai_client()
client = await _get_genai_client()

extras = []
if QueryIntent.ACCESS_DOWNLOAD.value in intents:
Expand Down