-
Notifications
You must be signed in to change notification settings - Fork 40
test: Add comprehensive backend unit test suite (Resolves #63) #104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| import os | ||
| import sys | ||
| import pytest | ||
| from unittest.mock import patch, MagicMock | ||
|
|
||
| # Add the backend directory to sys.path so tests can import from it | ||
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | ||
|
|
||
| # Mock heavy dependencies that are not installed in the global env | ||
| # to allow unit tests to run without them. | ||
| mocked_modules = [ | ||
| 'langgraph', | ||
| 'langgraph.graph', | ||
| 'torch', | ||
| 'google', | ||
| 'google.cloud', | ||
| 'google.cloud.aiplatform', | ||
| 'google.cloud.bigquery', | ||
| 'google.genai', | ||
| 'google.genai.types', | ||
| 'transformers', | ||
| 'ks_search_tool' | ||
| ] | ||
|
|
||
| for mod in mocked_modules: | ||
| sys.modules[mod] = MagicMock() | ||
|
|
||
| # agents.py imports END from langgraph.graph | ||
| sys.modules['langgraph.graph'].END = "END" | ||
|
|
||
|
Comment on lines +25 to +30
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I tried moving this to a `monkeypatch` fixture as suggested, but because these heavy dependencies (langgraph, torch, google-cloud-aiplatform) are intentionally not installed in the testing environment, pytest crashes with `ModuleNotFoundError` during the initial collection phase before the fixture even executes. Leaving the mock at the module level is required so pytest can successfully collect the test files without those dependencies installed. |
||
| @pytest.fixture(autouse=True) | ||
| def mock_env_vars(): | ||
| env_patcher = patch.dict(os.environ, { | ||
| "GOOGLE_API_KEY": "test-key-123", | ||
| "GCP_PROJECT_ID": "", | ||
| "GEMINI_USE_VERTEX": "false", | ||
| "CORS_ALLOW_ORIGINS": "*", | ||
| "ENVIRONMENT": "test" | ||
| }, clear=False) | ||
|
|
||
| env_patcher.start() | ||
| yield | ||
| env_patcher.stop() | ||
|
|
||
| @pytest.fixture | ||
| def test_client(): | ||
| from main import app | ||
| from fastapi.testclient import TestClient | ||
| return TestClient(app) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| from unittest.mock import patch, AsyncMock, MagicMock | ||
|
|
||
| from agents import ( | ||
| _is_more_query, | ||
| QueryIntent, | ||
| fuse_results, | ||
| AgentState, | ||
| NeuroscienceAssistant | ||
| ) | ||
|
|
||
| def test_is_more_query(): | ||
| assert _is_more_query("next 10") == 10 | ||
| assert _is_more_query("show 5") == 5 | ||
| assert _is_more_query("more 20") == 20 | ||
|
|
||
| assert _is_more_query("more") is None | ||
| assert _is_more_query("continue") is None | ||
|
|
||
| assert _is_more_query("find rat electrophysiology") is None | ||
| assert _is_more_query("") is None | ||
|
|
||
| def test_fuse_results(): | ||
| state: AgentState = { | ||
| "session_id": "test_session", | ||
| "query": "rat data", | ||
| "history": [], | ||
| "keywords": [], | ||
| "effective_query": "rat data", | ||
| "intents": [], | ||
| "ks_results": [{"_id": "doc_common", "_score": 10.0}, {"_id": "doc_ks_only", "_score": 5.0}], | ||
| "vector_results": [{"id": "doc_common", "similarity": 0.8}, {"id": "doc_vec_only", "similarity": 0.9}], | ||
| "final_results": [], | ||
| "all_results": [], | ||
| "start_number": 1, | ||
| "previous_text": "", | ||
| "final_response": "", | ||
| } | ||
|
|
||
| new_state = fuse_results(state) | ||
| all_res = new_state["all_results"] | ||
|
|
||
| assert len(all_res) == 3 | ||
|
|
||
| # doc_common score: vector (0.8 * 0.6 = 0.48) + ks (10.0 * 0.4 = 4.0) = 4.48 | ||
| # doc_vec_only score: vector (0.9 * 0.6 = 0.54) + ks (0) = 0.54 | ||
| # doc_ks_only score: vector (0) + ks (5.0 * 0.4 = 2.0) = 2.0 | ||
| doc_ids = [res.get("id") or res.get("_id") for res in all_res] | ||
| assert doc_ids == ["doc_common", "doc_ks_only", "doc_vec_only"] | ||
|
|
||
| def test_neuroscience_assistant_reset(): | ||
| assistant = NeuroscienceAssistant() | ||
|
|
||
| assistant.chat_history["session_123"] = ["User: Hello", "Assistant: Hi"] | ||
| assistant.session_memory["session_123"] = {"page": 1} | ||
|
|
||
| assistant.reset_session("session_123") | ||
|
|
||
| assert "session_123" not in assistant.chat_history | ||
| assert "session_123" not in assistant.session_memory |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| from unittest.mock import patch, AsyncMock | ||
|
|
||
| def test_root_endpoint(test_client): | ||
| response = test_client.get("/") | ||
| assert response.status_code == 200 | ||
| assert "KnowledgeSpace AI Backend is running" in response.json()["message"] | ||
|
|
||
| def test_health_check_endpoint(test_client): | ||
| response = test_client.get("/health") | ||
| assert response.status_code == 200 | ||
| assert response.json()["status"] == "healthy" | ||
|
|
||
| @patch("main.asyncio.wait_for") | ||
| def test_api_health_endpoint(mock_wait_for, test_client): | ||
| mock_wait_for.return_value = True | ||
|
|
||
| response = test_client.get("/api/health") | ||
| assert response.status_code == 200 | ||
|
|
||
| data = response.json() | ||
| assert data["status"] == "healthy" | ||
| assert data["components"]["vector_search"] == "enabled" | ||
| assert data["components"]["llm"] == "enabled" # enabled via conftest env patch | ||
|
|
||
| @patch("main.assistant") | ||
| def test_chat_endpoint_success(mock_assistant, test_client): | ||
| mock_assistant.handle_chat = AsyncMock(return_value="Found 3 datasets for rat hippocampus...") | ||
|
|
||
| payload = { | ||
| "query": "find rat hippocampus data", | ||
| "session_id": "session_123", | ||
| "reset": False | ||
| } | ||
|
|
||
| response = test_client.post("/api/chat", json=payload) | ||
|
|
||
| assert response.status_code == 200 | ||
| data = response.json() | ||
| assert data["response"] == "Found 3 datasets for rat hippocampus..." | ||
| assert "process_time" in data["metadata"] | ||
| assert data["metadata"]["session_id"] == "session_123" | ||
| assert data["metadata"]["reset"] is False | ||
|
|
||
| mock_assistant.handle_chat.assert_called_once_with( | ||
| session_id="session_123", | ||
| query="find rat hippocampus data", | ||
| reset=False | ||
| ) | ||
|
|
||
| @patch("main.assistant") | ||
| def test_session_reset_endpoint(mock_assistant, test_client): | ||
| response = test_client.post("/api/session/reset", json={"session_id": "session_456"}) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert response.json()["status"] == "ok" | ||
|
|
||
| mock_assistant.reset_session.assert_called_once_with("session_456") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| from unittest.mock import patch, MagicMock | ||
|
|
||
| from retrieval import VertexRetriever, get_retriever | ||
| from local_retriever import LocalRetriever | ||
|
|
||
| def test_local_retriever(): | ||
| retriever = LocalRetriever() | ||
| assert retriever.is_enabled is True | ||
| assert retriever.search("test query") == [] | ||
|
|
||
| @patch("retrieval.os.getenv") | ||
| @patch("retrieval.aiplatform") | ||
| @patch("retrieval.bigquery") | ||
| @patch("retrieval.AutoTokenizer") | ||
| @patch("retrieval.AutoModel") | ||
| @patch("retrieval.torch.cuda.is_available", return_value=False) | ||
| def test_vertex_retriever_initialization(mock_cuda, mock_model, mock_tokenizer, mock_bq, mock_ai, mock_getenv): | ||
| def mock_env(key, default=""): | ||
| mapping = { | ||
| "GCP_PROJECT_ID": "test-project", | ||
| "GCP_REGION": "us-central1", | ||
| "INDEX_ENDPOINT_ID_FULL": "projects/123/locations/us-central1/indexEndpoints/456", | ||
| "DEPLOYED_INDEX_ID": "test_index", | ||
| } | ||
| return mapping.get(key, default) | ||
|
|
||
| mock_getenv.side_effect = mock_env | ||
|
|
||
| # prevent actual model downloads during test | ||
| mock_model.from_pretrained.return_value.eval.return_value.to.return_value = MagicMock() | ||
|
|
||
| retriever = VertexRetriever() | ||
|
|
||
| assert retriever.is_enabled is True | ||
| mock_ai.init.assert_called_once_with(project="test-project", location="us-central1") | ||
| mock_bq.Client.assert_called_once_with(project="test-project") | ||
| mock_tokenizer.from_pretrained.assert_called_once() | ||
| mock_model.from_pretrained.assert_called_once() | ||
|
|
||
| def test_get_retriever_fallback(): | ||
| # fallback happens because conftest sets GCP_PROJECT_ID to an empty string | ||
| retriever = get_retriever() | ||
| assert isinstance(retriever, LocalRetriever) |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for catching the typo in the PR description! You are correct that `pip install -e .[dev]` should be run from the repository root, not the backend directory. However, I have intentionally kept the `sys.path.insert` in `conftest.py` as a fallback. This ensures the test suite still runs flawlessly even if a developer hasn't installed the package in editable mode.