Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions backend/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def __init__(self):
self.query_char_limit = 8000

# Enable only if everything is present
self.is_enabled = all(
self._is_enabled = all(
[self.project_id, self.region, self.index_endpoint_full, self.deployed_id]
)
if not self.is_enabled:
if not self._is_enabled:
logger.warning(
"Vector search disabled due to incomplete GCP env: "
f"project={bool(self.project_id)}, region={bool(self.region)}, "
Expand All @@ -109,7 +109,7 @@ def __init__(self):
self.bq = bigquery.Client(project=self.project_id)
except Exception as e:
logger.error(f"GCP client initialization failed: {e}")
self.is_enabled = False
self._is_enabled = False
return

try:
Expand All @@ -123,7 +123,11 @@ def __init__(self):
logger.info(f"Vector search initialized on device={self.device} using {self.embed_model_name}")
except Exception as e:
logger.error(f"Embedding model initialization failed: {e}")
self.is_enabled = False
self._is_enabled = False

@property
def is_enabled(self) -> bool:
    """Report whether the retriever is usable.

    Reads the private ``_is_enabled`` flag. If the flag has not been set on
    the instance, ``False`` is returned instead of raising
    ``AttributeError``. The result is always a real ``bool``.
    """
    return bool(getattr(self, "_is_enabled", False))


@is_enabled.setter
def is_enabled(self, value: bool) -> None:
    """Keep ``obj.is_enabled = flag`` working.

    ``is_enabled`` used to be a plain, writable attribute; without a setter,
    any caller that still assigns to it would hit ``AttributeError``.
    The value is stored in ``_is_enabled`` after coercion to ``bool``.
    """
    self._is_enabled = bool(value)

# Embedding
def _embed(self, text: str) -> List[float]:
Expand Down
49 changes: 49 additions & 0 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
import sys
import pytest
from unittest.mock import patch, MagicMock

# Add the backend directory to sys.path so tests can import from it
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

Comment on lines +6 to +8
Copy link
Copy Markdown
Author

@shamsulalam1114 shamsulalam1114 Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching the typo in the PR description! You are correct that pip install -e .[dev] should be run from the repository root, not the backend directory. However, I have intentionally kept the sys.path.insert in conftest.py as a fallback. This ensures the test suite still runs flawlessly even if a developer hasn't installed the package in editable mode.

# Mock heavy dependencies that are not installed in the global env
# to allow unit tests to run without them.
# This runs at import time (not inside a fixture) so the stubs are already in
# place when pytest imports the test modules during collection.
mocked_modules = [
    'langgraph',
    'langgraph.graph',
    'torch',
    'google',
    'google.cloud',
    'google.cloud.aiplatform',
    'google.cloud.bigquery',
    'google.genai',
    'google.genai.types',
    'transformers',
    'ks_search_tool',
]

# Parents are listed before their children, so a parent stub always exists by
# the time one of its submodules is registered.
for mod in mocked_modules:
    stub = MagicMock()
    sys.modules[mod] = stub
    # Attach the submodule stub to its parent stub so that attribute access
    # (e.g. ``langgraph.graph``) and ``sys.modules['langgraph.graph']`` resolve
    # to the same object instead of two unrelated MagicMocks.
    parent_name, _sep, child_name = mod.rpartition(".")
    if parent_name in mocked_modules:
        setattr(sys.modules[parent_name], child_name, stub)

# agents.py imports END from langgraph.graph
sys.modules['langgraph.graph'].END = "END"

Comment on lines +25 to +30
Copy link
Copy Markdown
Author

@shamsulalam1114 shamsulalam1114 Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried moving this to a monkeypatch fixture as suggested, but because these heavy dependencies (langgraph, torch, google-cloud-aiplatform) are intentionally not installed in the testing environment, pytest crashes with ModuleNotFoundError during the initial collection phase before the fixture even executes. Leaving the mock at the module level is required so pytest can successfully collect the test files without those dependencies installed.

@pytest.fixture(autouse=True)
def mock_env_vars():
    """Overlay a fixed set of environment variables for every test.

    The overrides are applied on top of the existing environment
    (``clear=False``) and are rolled back once the test has finished.
    """
    overrides = {
        "GOOGLE_API_KEY": "test-key-123",
        "GCP_PROJECT_ID": "",
        "GEMINI_USE_VERTEX": "false",
        "CORS_ALLOW_ORIGINS": "*",
        "ENVIRONMENT": "test",
    }
    # The context manager starts the patch on entry and stops it on exit.
    with patch.dict(os.environ, overrides, clear=False):
        yield

@pytest.fixture
def test_client():
    """Build a FastAPI ``TestClient`` bound to the application in ``main``."""
    # Imported lazily so that ``main`` is only loaded when a test asks for
    # the client; ``main`` is imported before ``fastapi.testclient`` as in the
    # original.
    from main import app
    from fastapi.testclient import TestClient

    client = TestClient(app)
    return client
59 changes: 59 additions & 0 deletions backend/tests/test_agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from unittest.mock import patch, AsyncMock, MagicMock

from agents import (
_is_more_query,
QueryIntent,
fuse_results,
AgentState,
NeuroscienceAssistant
)

def test_is_more_query():
    """Check which inputs ``_is_more_query`` maps to a count and which to ``None``."""
    # Inputs carrying an explicit number yield that number.
    with_count = (("next 10", 10), ("show 5", 5), ("more 20", 20))
    for text, expected in with_count:
        assert _is_more_query(text) == expected

    # Bare continuation words, ordinary search queries and the empty string
    # all yield None.
    without_count = ("more", "continue", "find rat electrophysiology", "")
    for text in without_count:
        assert _is_more_query(text) is None

def test_fuse_results():
    """Fuse two overlapping result lists and check the size and order of the output."""
    keyword_hits = [
        {"_id": "doc_common", "_score": 10.0},
        {"_id": "doc_ks_only", "_score": 5.0},
    ]
    vector_hits = [
        {"id": "doc_common", "similarity": 0.8},
        {"id": "doc_vec_only", "similarity": 0.9},
    ]
    state: AgentState = {
        "session_id": "test_session",
        "query": "rat data",
        "history": [],
        "keywords": [],
        "effective_query": "rat data",
        "intents": [],
        "ks_results": keyword_hits,
        "vector_results": vector_hits,
        "final_results": [],
        "all_results": [],
        "start_number": 1,
        "previous_text": "",
        "final_response": "",
    }

    fused = fuse_results(state)["all_results"]

    # Three distinct documents go in, so three are expected out.
    assert len(fused) == 3

    # Expected ranking, assuming fusion weights of 0.6 (vector) and
    # 0.4 (keyword) -- TODO confirm against fuse_results:
    #   doc_common   -> 0.8 * 0.6 + 10.0 * 0.4 = 4.48
    #   doc_ks_only  -> 0         +  5.0 * 0.4 = 2.0
    #   doc_vec_only -> 0.9 * 0.6 + 0          = 0.54
    ranked_ids = []
    for entry in fused:
        # Vector hits carry "id", keyword hits carry "_id".
        ranked_ids.append(entry.get("id") or entry.get("_id"))
    assert ranked_ids == ["doc_common", "doc_ks_only", "doc_vec_only"]

def test_neuroscience_assistant_reset():
    """``reset_session`` removes the session from both per-session stores."""
    sid = "session_123"
    assistant = NeuroscienceAssistant()

    # Seed both stores so there is something to remove.
    assistant.chat_history[sid] = ["User: Hello", "Assistant: Hi"]
    assistant.session_memory[sid] = {"page": 1}

    assistant.reset_session(sid)

    assert sid not in assistant.chat_history
    assert sid not in assistant.session_memory
57 changes: 57 additions & 0 deletions backend/tests/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from unittest.mock import patch, AsyncMock

def test_root_endpoint(test_client):
    """``GET /`` answers 200 with the backend's "running" banner in ``message``."""
    resp = test_client.get("/")
    assert resp.status_code == 200
    banner = resp.json()["message"]
    assert "KnowledgeSpace AI Backend is running" in banner

def test_health_check_endpoint(test_client):
    """``GET /health`` answers 200 and reports a ``healthy`` status."""
    resp = test_client.get("/health")
    assert resp.status_code == 200
    reported = resp.json()["status"]
    assert reported == "healthy"

@patch("main.asyncio.wait_for", return_value=True)
def test_api_health_endpoint(mock_wait_for, test_client):
    """``GET /api/health`` reports a healthy service with both components enabled.

    ``asyncio.wait_for`` (as seen from ``main``) is replaced by a mock that
    returns ``True``.
    """
    resp = test_client.get("/api/health")
    assert resp.status_code == 200

    body = resp.json()
    assert body["status"] == "healthy"

    components = body["components"]
    assert components["vector_search"] == "enabled"
    # enabled via conftest env patch
    assert components["llm"] == "enabled"

@patch("main.assistant")
def test_chat_endpoint_success(mock_assistant, test_client):
    """``POST /api/chat`` returns the assistant's reply plus request metadata."""
    reply = "Found 3 datasets for rat hippocampus..."
    # handle_chat is awaited by the endpoint, so it has to be an AsyncMock.
    mock_assistant.handle_chat = AsyncMock(return_value=reply)

    request_body = dict(
        query="find rat hippocampus data",
        session_id="session_123",
        reset=False,
    )

    resp = test_client.post("/api/chat", json=request_body)

    assert resp.status_code == 200
    body = resp.json()
    assert body["response"] == "Found 3 datasets for rat hippocampus..."

    meta = body["metadata"]
    assert "process_time" in meta
    assert meta["session_id"] == "session_123"
    assert meta["reset"] is False

    # The endpoint must forward the request fields as keyword arguments.
    mock_assistant.handle_chat.assert_called_once_with(
        session_id="session_123",
        query="find rat hippocampus data",
        reset=False,
    )

@patch("main.assistant")
def test_session_reset_endpoint(mock_assistant, test_client):
    """``POST /api/session/reset`` answers ``ok`` and resets the given session."""
    request_body = {"session_id": "session_456"}
    resp = test_client.post("/api/session/reset", json=request_body)

    assert resp.status_code == 200
    assert resp.json()["status"] == "ok"

    # The session id is passed positionally to the assistant.
    mock_assistant.reset_session.assert_called_once_with("session_456")
43 changes: 43 additions & 0 deletions backend/tests/test_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from unittest.mock import patch, MagicMock

from retrieval import VertexRetriever, get_retriever
from local_retriever import LocalRetriever

def test_local_retriever():
    """A fresh ``LocalRetriever`` is enabled and returns no hits for a query."""
    local = LocalRetriever()
    assert local.is_enabled is True
    hits = local.search("test query")
    assert hits == []

@patch("retrieval.os.getenv")
@patch("retrieval.aiplatform")
@patch("retrieval.bigquery")
@patch("retrieval.AutoTokenizer")
@patch("retrieval.AutoModel")
@patch("retrieval.torch.cuda.is_available", return_value=False)
def test_vertex_retriever_initialization(mock_cuda, mock_model, mock_tokenizer, mock_bq, mock_ai, mock_getenv):
    """With a complete GCP environment, ``VertexRetriever`` ends up enabled.

    Every external client and model loader is patched; the test checks that
    each one is invoked once with the values taken from the fake environment.
    Mock arguments arrive bottom-up relative to the ``@patch`` stack.
    """
    fake_env = {
        "GCP_PROJECT_ID": "test-project",
        "GCP_REGION": "us-central1",
        "INDEX_ENDPOINT_ID_FULL": "projects/123/locations/us-central1/indexEndpoints/456",
        "DEPLOYED_INDEX_ID": "test_index",
    }
    # Known keys come from fake_env; anything else yields the caller's default.
    mock_getenv.side_effect = lambda key, default="": fake_env.get(key, default)

    # prevent actual model downloads during test
    loaded_model = mock_model.from_pretrained.return_value
    loaded_model.eval.return_value.to.return_value = MagicMock()

    retriever = VertexRetriever()

    assert retriever.is_enabled is True
    mock_ai.init.assert_called_once_with(project="test-project", location="us-central1")
    mock_bq.Client.assert_called_once_with(project="test-project")
    mock_tokenizer.from_pretrained.assert_called_once()
    mock_model.from_pretrained.assert_called_once()

def test_get_retriever_fallback():
    """``get_retriever`` hands back a ``LocalRetriever`` in the test environment."""
    # conftest sets GCP_PROJECT_ID to an empty string for every test, which is
    # presumably what triggers the fallback -- verify against get_retriever.
    chosen = get_retriever()
    assert isinstance(chosen, LocalRetriever)