Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/google/adk/tools/vertex_ai_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import logging
from typing import Optional
from typing import TYPE_CHECKING

Expand All @@ -25,6 +26,8 @@
from .base_tool import BaseTool
from .tool_context import ToolContext

_logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from ..models import LlmRequest

Expand All @@ -37,6 +40,26 @@ class VertexAiSearchTool(BaseTool):
search_engine_id: The Vertex AI search engine resource ID.
"""

@staticmethod
def _extract_resource_id(resource_path: str, resource_type: str) -> str:
"""Extracts the resource ID from a full resource path.

Args:
resource_path: The full resource path (e.g., "projects/p/locations/l/collections/c/engines/e")
resource_type: The type of resource to extract (e.g., 'engines', 'dataStores')

Returns:
The extracted resource ID
"""
parts = resource_path.split('/')
try:
resource_index = parts.index(resource_type)
if resource_index + 1 < len(parts):
return parts[resource_index + 1]
except ValueError:
pass
return resource_path # Return original if pattern not matched

def __init__(
self,
*,
Expand Down Expand Up @@ -83,6 +106,11 @@ def __init__(
self.data_store_id = data_store_id
self.data_store_specs = data_store_specs
self.search_engine_id = search_engine_id
self._search_engine_name = (
self._extract_resource_id(search_engine_id, 'engines')
if search_engine_id
else None
)
self.filter = filter
self.max_results = max_results
self.bypass_multi_tools_limit = bypass_multi_tools_limit
Expand All @@ -102,6 +130,15 @@ async def process_llm_request(
)
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
_logger.debug(
'Adding Vertex AI Search tool config to LLM request: datastore=%s,'
' engine=%s, filter=%s, max_results=%s, data_store_specs=%s',
self.data_store_id,
self._search_engine_name or self.search_engine_id,
self.filter,
self.max_results,
self.data_store_specs,
)
llm_request.config.tools.append(
types.Tool(
retrieval=types.Retrieval(
Expand Down
146 changes: 132 additions & 14 deletions tests/unittests/tools/test_vertex_ai_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.models.llm_request import LlmRequest
Expand All @@ -24,6 +26,8 @@
from google.genai import types
import pytest

VERTEX_SEARCH_TOOL_LOGGER_NAME = 'google.adk.tools.vertex_ai_search_tool'


async def _create_tool_context() -> ToolContext:
session_service = InMemorySessionService()
Expand Down Expand Up @@ -121,12 +125,32 @@ def test_init_with_data_store_id(self):
tool = VertexAiSearchTool(data_store_id='test_data_store')
assert tool.data_store_id == 'test_data_store'
assert tool.search_engine_id is None
assert tool.data_store_specs is None

def test_init_with_search_engine_id(self):
  """Constructing with only a search engine ID leaves the other fields unset."""
  search_tool = VertexAiSearchTool(search_engine_id='test_search_engine')
  # Only the engine ID should be populated; the datastore fields stay None.
  assert search_tool.data_store_id is None
  assert search_tool.data_store_specs is None
  assert search_tool.search_engine_id == 'test_search_engine'

def test_init_with_engine_and_specs(self):
  """A search engine ID combined with data store specs is stored verbatim."""
  engine_path = (
      'projects/p/locations/l/collections/default_collection/engines/test_search_engine'
  )
  store_specs = [
      types.VertexAISearchDataStoreSpec(
          dataStore='projects/p/locations/l/collections/default_collection/dataStores/spec_store_id'
      )
  ]
  tool = VertexAiSearchTool(
      search_engine_id=engine_path,
      data_store_specs=store_specs,
  )
  # The full engine path and the spec list are kept as-is; no datastore ID.
  assert tool.search_engine_id == engine_path
  assert tool.data_store_specs == store_specs
  assert tool.data_store_id is None

def test_init_with_neither_raises_error(self):
"""Test that initialization without either ID raises ValueError."""
Expand All @@ -146,10 +170,27 @@ def test_init_with_both_raises_error(self):
data_store_id='test_data_store', search_engine_id='test_search_engine'
)

def test_init_with_specs_but_no_engine_raises_error(self):
  """Passing data_store_specs without an engine ID must raise ValueError."""
  orphan_specs = [
      types.VertexAISearchDataStoreSpec(
          dataStore='projects/p/locations/l/collections/default_collection/dataStores/spec_store_id'
      )
  ]
  # Specs alone are not enough — one of the two primary IDs is required.
  expected_message = 'Either data_store_id or search_engine_id must be specified'
  with pytest.raises(ValueError, match=expected_message):
    VertexAiSearchTool(data_store_specs=orphan_specs)

@pytest.mark.asyncio
async def test_process_llm_request_with_simple_gemini_model(self):
async def test_process_llm_request_with_simple_gemini_model(self, caplog):
"""Test processing LLM request with simple Gemini model name."""
tool = VertexAiSearchTool(data_store_id='test_data_store')
caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME)

tool = VertexAiSearchTool(
data_store_id='test_data_store', filter='f', max_results=5
)
tool_context = await _create_tool_context()

llm_request = LlmRequest(
Expand All @@ -162,17 +203,50 @@ async def test_process_llm_request_with_simple_gemini_model(self):

assert llm_request.config.tools is not None
assert len(llm_request.config.tools) == 1
assert llm_request.config.tools[0].retrieval is not None
assert llm_request.config.tools[0].retrieval.vertex_ai_search is not None
retrieval_tool = llm_request.config.tools[0]
assert retrieval_tool.retrieval is not None
assert retrieval_tool.retrieval.vertex_ai_search is not None
assert (
retrieval_tool.retrieval.vertex_ai_search.datastore == 'test_data_store'
)
assert retrieval_tool.retrieval.vertex_ai_search.engine is None
assert retrieval_tool.retrieval.vertex_ai_search.filter == 'f'
assert retrieval_tool.retrieval.vertex_ai_search.max_results == 5
assert retrieval_tool.retrieval.vertex_ai_search.data_store_specs is None

# Check for debug log message and its components
debug_records = [
r for r in caplog.records if 'Adding Vertex AI Search tool config' in r.message
]
assert len(debug_records) == 1
log_message = debug_records[0].getMessage()
assert 'Adding Vertex AI Search tool config to LLM request' in log_message
assert 'datastore=test_data_store' in log_message
assert 'engine=None' in log_message
assert 'filter=f' in log_message
assert 'max_results=5' in log_message
assert 'data_store_specs=None' in log_message

@pytest.mark.asyncio
async def test_process_llm_request_with_path_based_gemini_model(self):
async def test_process_llm_request_with_path_based_gemini_model(self, caplog):
"""Test processing LLM request with path-based Gemini model name."""
tool = VertexAiSearchTool(data_store_id='test_data_store')
caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME)

specs = [
types.VertexAISearchDataStoreSpec(
dataStore='projects/p/locations/l/collections/default_collection/dataStores/spec_store_id'
)
]
tool = VertexAiSearchTool(
search_engine_id='projects/p/locations/l/collections/default_collection/engines/test_engine',
data_store_specs=specs,
filter='f2',
max_results=10,
)
tool_context = await _create_tool_context()

llm_request = LlmRequest(
model='projects/265104255505/locations/us-central1/publishers/google/models/gemini-2.0-flash-001',
model='projects/p/locations/l/publishers/g/models/gemini-2.0-flash-001',
config=types.GenerateContentConfig(),
)

Expand All @@ -182,8 +256,30 @@ async def test_process_llm_request_with_path_based_gemini_model(self):

assert llm_request.config.tools is not None
assert len(llm_request.config.tools) == 1
assert llm_request.config.tools[0].retrieval is not None
assert llm_request.config.tools[0].retrieval.vertex_ai_search is not None
retrieval_tool = llm_request.config.tools[0]
assert retrieval_tool.retrieval is not None
assert retrieval_tool.retrieval.vertex_ai_search is not None
assert retrieval_tool.retrieval.vertex_ai_search.datastore is None
assert (
retrieval_tool.retrieval.vertex_ai_search.engine
== 'projects/p/locations/l/collections/default_collection/engines/test_engine'
)
assert retrieval_tool.retrieval.vertex_ai_search.filter == 'f2'
assert retrieval_tool.retrieval.vertex_ai_search.max_results == 10
assert retrieval_tool.retrieval.vertex_ai_search.data_store_specs == specs

# Check for debug log message and its components
debug_messages = [
r.message for r in caplog.records if r.levelno == logging.DEBUG
]
debug_message = '\n'.join(debug_messages)
assert 'Adding Vertex AI Search tool config to LLM request' in debug_message
assert 'datastore=None' in debug_message
assert 'engine=test_engine' in debug_message
assert 'filter=f2' in debug_message
assert 'max_results=10' in debug_message
assert 'data_store_specs=' in debug_message
assert 'spec_store_id' in debug_message

@pytest.mark.asyncio
async def test_process_llm_request_with_gemini_1_and_other_tools_raises_error(
Expand Down Expand Up @@ -230,7 +326,9 @@ async def test_process_llm_request_with_path_based_gemini_1_and_other_tools_rais
)

llm_request = LlmRequest(
model='projects/265104255505/locations/us-central1/publishers/google/models/gemini-1.5-pro-preview',
model=(
'projects/p/locations/l/publishers/g/models/gemini-1.5-pro-preview'
),
config=types.GenerateContentConfig(tools=[existing_tool]),
)

Expand Down Expand Up @@ -273,7 +371,9 @@ async def test_process_llm_request_with_path_based_non_gemini_model_raises_error
tool = VertexAiSearchTool(data_store_id='test_data_store')
tool_context = await _create_tool_context()

non_gemini_path = 'projects/265104255505/locations/us-central1/publishers/google/models/claude-3-sonnet'
non_gemini_path = (
'projects/p/locations/l/publishers/g/models/claude-3-sonnet'
)
llm_request = LlmRequest(
model=non_gemini_path, config=types.GenerateContentConfig()
)
Expand All @@ -291,9 +391,11 @@ async def test_process_llm_request_with_path_based_non_gemini_model_raises_error

@pytest.mark.asyncio
async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds(
self,
self, caplog
):
"""Test that Gemini 2.x with other tools succeeds."""
caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME)

tool = VertexAiSearchTool(data_store_id='test_data_store')
tool_context = await _create_tool_context()

Expand All @@ -316,5 +418,21 @@ async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds(
assert llm_request.config.tools is not None
assert len(llm_request.config.tools) == 2
assert llm_request.config.tools[0] == existing_tool
assert llm_request.config.tools[1].retrieval is not None
assert llm_request.config.tools[1].retrieval.vertex_ai_search is not None
retrieval_tool = llm_request.config.tools[1]
assert retrieval_tool.retrieval is not None
assert retrieval_tool.retrieval.vertex_ai_search is not None
assert (
retrieval_tool.retrieval.vertex_ai_search.datastore == 'test_data_store'
)

debug_records = [
r for r in caplog.records if 'Adding Vertex AI Search tool config' in r.message
]
assert len(debug_records) == 1
log_message = debug_records[0].getMessage()
assert 'Adding Vertex AI Search tool config to LLM request' in log_message
assert 'datastore=test_data_store' in log_message
assert 'engine=None' in log_message
assert 'filter=None' in log_message
assert 'max_results=None' in log_message
assert 'data_store_specs=None' in log_message
Loading