Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from google.cloud.aiplatform import initializer
from vertexai.preview import reasoning_engines
from vertexai.agent_engines import _utils

import asyncio
import pytest


Expand Down Expand Up @@ -111,6 +113,33 @@ def run(self, *args, **kwargs):
}
)

async def run_async(self, *args, **kwargs):
from google.adk.events import event

yield event.Event(
**{
"author": "currency_exchange_agent",
"content": {
"parts": [
{
"function_call": {
"args": {
"currency_date": "2025-04-03",
"currency_from": "USD",
"currency_to": "SEK",
},
"id": "af-c5a57692-9177-4091-a3df-098f834ee849",
"name": "get_exchange_rate",
}
}
],
"role": "model",
},
"id": "9aaItGK9",
"invocation_id": "e-6543c213-6417-484b-9551-b67915d1d5f7",
}
)


@pytest.mark.usefixtures("google_auth_mock")
class TestAdkApp:
Expand Down Expand Up @@ -195,6 +224,45 @@ def test_stream_query_with_content(self):
)
assert len(events) == 1

@pytest.mark.asyncio
async def test_async_stream_query(self):
app = reasoning_engines.AdkApp(
agent=Agent(name="test_agent", model=_TEST_MODEL)
)
assert app._tmpl_attrs.get("runner") is None
app.set_up()
app._tmpl_attrs["runner"] = _MockRunner()
events = []
async for event in app.async_stream_query(
user_id="test_user_id",
message="test message",
):
events.append(event)
assert len(events) == 1

@pytest.mark.asyncio
async def test_async_stream_query_with_content(self):
app = reasoning_engines.AdkApp(
agent=Agent(name="test_agent", model=_TEST_MODEL)
)
assert app._tmpl_attrs.get("runner") is None
app.set_up()
app._tmpl_attrs["runner"] = _MockRunner()
events = []
async for event in app.async_stream_query(
user_id="test_user_id",
message=types.Content(
role="user",
parts=[
types.Part(
text="test message with content",
)
],
).model_dump(),
):
events.append(event)
assert len(events) == 1

def test_streaming_agent_run_with_events(self):
app = reasoning_engines.AdkApp(
agent=Agent(name="test_agent", model=_TEST_MODEL)
Expand Down Expand Up @@ -322,3 +390,27 @@ def test_raise_get_session_not_found_error(self):
user_id="non_existent_user",
session_id="test_session_id",
)

def test_stream_query_invalid_message_type(self):
app = reasoning_engines.AdkApp(
agent=Agent(name="test_agent", model=_TEST_MODEL)
)
with pytest.raises(
TypeError,
match="message must be a string or a dictionary representing a Content object.",
):
list(app.stream_query(user_id="test_user_id", message=123))

@pytest.mark.asyncio
async def test_async_stream_query_invalid_message_type(self):
app = reasoning_engines.AdkApp(
agent=Agent(name="test_agent", model=_TEST_MODEL)
)
with pytest.raises(
TypeError,
match="message must be a string or a dictionary representing a Content object.",
):
async for _ in app.async_stream_query(
user_id="test_user_id", message=123
):
pass
Loading