Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/google/adk/agents/run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,14 @@ class RunConfig(BaseModel):
)
"""

model_input_context: list[types.Content] | None = None
"""Transient context to include in the model input for this invocation.

The Runner does not persist these contents to the session. They are only
added to the LLM request assembled for the current invocation, which lets
callers provide per-turn context without changing the conversation history.
"""

@model_validator(mode='before')
@classmethod
def check_for_deprecated_save_live_audio(cls, data: Any) -> Any:
Expand Down
30 changes: 30 additions & 0 deletions src/google/adk/flows/llm_flows/contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ async def run_async(
preserve_function_call_ids=preserve_function_call_ids,
)

if (
invocation_context.run_config
and invocation_context.run_config.model_input_context
):
_add_model_input_context_to_user_content(
invocation_context,
llm_request,
copy.deepcopy(invocation_context.run_config.model_input_context),
)

# Add instruction-related contents to proper position in conversation
await _add_instructions_to_user_content(
invocation_context, llm_request, instruction_related_contents
Expand Down Expand Up @@ -845,6 +855,26 @@ def _content_contains_function_response(content: types.Content) -> bool:
return False


def _add_model_input_context_to_user_content(
invocation_context: InvocationContext,
llm_request: LlmRequest,
model_input_context: list[types.Content],
) -> None:
"""Insert transient model input context before the invocation user content."""
if not model_input_context:
return

insert_index = 0
user_content = invocation_context.user_content
if user_content:
for i in range(len(llm_request.contents) - 1, -1, -1):
if llm_request.contents[i] == user_content:
insert_index = i
break

llm_request.contents[insert_index:insert_index] = model_input_context


async def _add_instructions_to_user_content(
invocation_context: InvocationContext,
llm_request: LlmRequest,
Expand Down
148 changes: 148 additions & 0 deletions tests/unittests/agents/test_llm_agent_include_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Unit tests for LlmAgent include_contents field behavior."""

from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.run_config import RunConfig
from google.adk.agents.sequential_agent import SequentialAgent
from google.genai import types
import pytest
Expand Down Expand Up @@ -189,6 +190,153 @@ def simple_tool(message: str) -> dict:
assert len(mock_model.requests[0].config.tools) > 0


def test_model_input_context_is_sent_to_model_without_persisting_to_session():
mock_model = testing_utils.MockModel.create(responses=["Answer"])
agent = LlmAgent(name="test_agent", model=mock_model)
runner = testing_utils.InMemoryRunner(agent)
session = runner.session

list(
runner.runner.run(
user_id=session.user_id,
session_id=session.id,
new_message=testing_utils.get_user_content("Question"),
run_config=RunConfig(
model_input_context=[
types.UserContent("Relevant context for this turn")
]
),
)
)

assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [
("user", "Relevant context for this turn"),
("user", "Question"),
]
assert testing_utils.simplify_events(runner.session.events) == [
("user", "Question"),
("test_agent", "Answer"),
]


def test_model_input_context_stays_before_user_message_after_tool_call():
def simple_tool(message: str) -> dict:
return {"result": f"Tool processed: {message}"}

mock_model = testing_utils.MockModel.create(
responses=[
types.Part.from_function_call(
name="simple_tool", args={"message": "payload"}
),
"Answer",
]
)
agent = LlmAgent(name="test_agent", model=mock_model, tools=[simple_tool])
runner = testing_utils.InMemoryRunner(agent)
session = runner.session

list(
runner.runner.run(
user_id=session.user_id,
session_id=session.id,
new_message=testing_utils.get_user_content("Question"),
run_config=RunConfig(
model_input_context=[
types.UserContent("Relevant context for this turn")
]
),
)
)

assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [
("user", "Relevant context for this turn"),
("user", "Question"),
]
assert testing_utils.simplify_contents(mock_model.requests[1].contents) == [
("user", "Relevant context for this turn"),
("user", "Question"),
(
"model",
types.Part.from_function_call(
name="simple_tool", args={"message": "payload"}
),
),
(
"user",
types.Part.from_function_response(
name="simple_tool",
response={"result": "Tool processed: payload"},
),
),
]
assert testing_utils.simplify_events(runner.session.events) == [
("user", "Question"),
(
"test_agent",
types.Part.from_function_call(
name="simple_tool", args={"message": "payload"}
),
),
(
"test_agent",
types.Part.from_function_response(
name="simple_tool",
response={"result": "Tool processed: payload"},
),
),
("test_agent", "Answer"),
]


def test_model_input_context_with_include_contents_none_sub_agent():
agent1_model = testing_utils.MockModel.create(
responses=["Agent1 response: XYZ"]
)
agent1 = LlmAgent(name="agent1", model=agent1_model)

agent2_model = testing_utils.MockModel.create(
responses=["Agent2 final response"]
)
agent2 = LlmAgent(
name="agent2",
model=agent2_model,
include_contents="none",
)
sequential_agent = SequentialAgent(
name="sequential_test_agent", sub_agents=[agent1, agent2]
)
runner = testing_utils.InMemoryRunner(sequential_agent)
session = runner.session

list(
runner.runner.run(
user_id=session.user_id,
session_id=session.id,
new_message=testing_utils.get_user_content("Original user request"),
run_config=RunConfig(
model_input_context=[
types.UserContent("Relevant context for this turn")
]
),
)
)

assert testing_utils.simplify_contents(agent1_model.requests[0].contents) == [
("user", "Relevant context for this turn"),
("user", "Original user request"),
]
assert testing_utils.simplify_contents(agent2_model.requests[0].contents) == [
("user", "Relevant context for this turn"),
(
"user",
[
types.Part(text="For context:"),
types.Part(text="[agent1] said: Agent1 response: XYZ"),
],
),
]


@pytest.mark.asyncio
async def test_include_contents_none_sequential_agents():
"""Test include_contents='none' with sequential agents."""
Expand Down
8 changes: 8 additions & 0 deletions tests/unittests/agents/test_run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,11 @@ def test_avatar_config_with_name():
assert run_config.avatar_config == avatar_config
assert run_config.avatar_config.avatar_name == "test_avatar"
assert run_config.avatar_config.customized_avatar is None


def test_model_input_context_accepts_transient_contents():
context_content = types.UserContent("Relevant context for this turn")

run_config = RunConfig(model_input_context=[context_content])

assert run_config.model_input_context == [context_content]
Loading