Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 108 additions & 52 deletions src/google/adk/cli/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,9 @@ def __init__(
self.trigger_sources = trigger_sources
self.default_llm_model = default_llm_model
self.default_app_name = os.getenv("ADK_DEFAULT_APP_NAME")
# Registry of active agent-run tasks keyed by session_id, so that the
# sessions/{session_id}:cancel API endpoint can look up and cancel a run.
self.active_tasks: dict[str, asyncio.Task[Any]] = {}

async def get_runner_async(self, app_name: str) -> Runner:
"""Returns the cached runner for the given app."""
Expand Down Expand Up @@ -1218,6 +1221,7 @@ async def update_session(

return session


@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
response_model_exclude_none=True,
Expand Down Expand Up @@ -1472,6 +1476,7 @@ async def worker():
raise HTTPException(status_code=404, detail=str(e)) from e

worker_task = asyncio.create_task(worker())
self.active_tasks[req.session_id] = worker_task

async def monitor():
try:
Expand Down Expand Up @@ -1502,6 +1507,7 @@ async def monitor():
raise
finally:
monitor_task.cancel()
self.active_tasks.pop(req.session_id, None)

@app.post("/run_sse")
async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse:
Expand All @@ -1518,11 +1524,6 @@ async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse:
_set_telemetry_context_if_needed(runner)

# Validate session existence before starting the stream.
# We check directly here instead of eagerly advancing the
# runner's async generator with anext(), because splitting
# generator consumption across two asyncio Tasks (request
# handler vs StreamingResponse) breaks OpenTelemetry context
# detachment.
if not runner.auto_create_session:
session = await self.session_service.get_session(
app_name=req.app_name,
Expand All @@ -1535,59 +1536,81 @@ async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse:
detail=f"Session not found: {req.session_id}",
)

# Convert the events to properly formatted SSE
async def event_generator():
async with Aclosing(
runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
run_config=RunConfig(
streaming_mode=stream_mode,
custom_metadata=req.custom_metadata,
),
invocation_id=req.invocation_id,
)
) as agen:
try:
# Use a queue to bridge the producer task (runs the agent) and
# the StreamingResponse consumer (formats SSE). This lets the
# /cancel endpoint cancel the producer task via the active_tasks
# registry.
event_queue: asyncio.Queue[Event | Exception | None] = asyncio.Queue()

async def produce_events() -> None:
try:
async with Aclosing(
runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
run_config=RunConfig(
streaming_mode=stream_mode,
custom_metadata=req.custom_metadata,
),
invocation_id=req.invocation_id,
)
) as agen:
async for event in agen:
# ADK Web renders artifacts from `actions.artifactDelta`
# during part processing *and* during action processing
# 1) the original event with `artifactDelta` cleared (content)
# 2) a content-less "action-only" event carrying `artifactDelta`
events_to_stream = [event]
if (
not req.function_call_event_id
and event.actions.artifact_delta
and event.content
and event.content.parts
):
content_event = event.model_copy(deep=True)
content_event.actions.artifact_delta = {}
artifact_event = event.model_copy(deep=True)
artifact_event.content = None
events_to_stream = [content_event, artifact_event]

for event_to_stream in events_to_stream:
sse_event = event_to_stream.model_dump_json(
exclude_none=True,
by_alias=True,
)
logger.debug(
"Generated event in agent run streaming: %s", sse_event
)
yield f"data: {sse_event}\n\n"
except Exception as e:
logger.exception("Error in event_generator: %s", e)
yield f"data: {json.dumps({'error': str(e)})}\n\n"
await event_queue.put(event)
except asyncio.CancelledError:
pass
except Exception as e: # pylint: disable=broad-exception-caught
await event_queue.put(e)
finally:
await event_queue.put(None) # sentinel

producer_task = asyncio.create_task(produce_events())
self.active_tasks[req.session_id] = producer_task

async def event_generator():
  """Yield SSE-formatted strings for the items the producer task enqueues.

  Consumes `event_queue` until the `None` sentinel arrives. An enqueued
  `Exception` is logged, reported to the client as a JSON `error` payload
  and ends the stream. Every other item is serialized with
  `model_dump_json` and emitted as a `data: ...` SSE frame.

  On exit (normal completion, error, or the generator being closed) the
  producer task is cancelled if it is still running and its entry is
  removed from `self.active_tasks`.

  Yields:
    SSE frames of the form `data: <json>` followed by a blank line.
  """
  try:
    while True:
      item = await event_queue.get()
      if item is None:  # sentinel: the producer has finished
        break
      if isinstance(item, Exception):
        # The exception was raised inside the producer task, so no
        # exception is being handled here. Pass it via exc_info so the
        # log record carries its traceback (logger.exception would only
        # pick up the currently handled exception, which is none).
        logger.error("Error in event_generator: %s", item, exc_info=item)
        yield f"data: {json.dumps({'error': str(item)})}\n\n"
        break

      # ADK Web renders artifacts from `actions.artifactDelta` during
      # part processing *and* during action processing, so an event with
      # both content parts and an artifact delta is streamed as:
      #   1) the original event with `artifactDelta` cleared (content)
      #   2) a content-less "action-only" event carrying `artifactDelta`
      events_to_stream = [item]
      if (
          not req.function_call_event_id
          and item.actions.artifact_delta
          and item.content
          and item.content.parts
      ):
        content_event = item.model_copy(deep=True)
        content_event.actions.artifact_delta = {}
        artifact_event = item.model_copy(deep=True)
        artifact_event.content = None
        events_to_stream = [content_event, artifact_event]

      for event_to_stream in events_to_stream:
        sse_event = event_to_stream.model_dump_json(
            exclude_none=True,
            by_alias=True,
        )
        logger.debug(
            "Generated event in agent run streaming: %s", sse_event
        )
        yield f"data: {sse_event}\n\n"
  finally:
    if not producer_task.done():
      producer_task.cancel()
    # Only drop the registry entry if it still refers to this run; a newer
    # request for the same session may have registered its own task.
    if self.active_tasks.get(req.session_id) is producer_task:
      self.active_tasks.pop(req.session_id, None)

# Returns a streaming response with the proper media type for SSE
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
)

@app.websocket("/run_live")
async def run_agent_live(
websocket: WebSocket,
Expand Down Expand Up @@ -1684,6 +1707,8 @@ async def process_messages():
asyncio.create_task(forward_events()),
asyncio.create_task(process_messages()),
]
# Register the first task (forward_events) under session_id so the :cancel
# endpoint can cancel it; the process_messages task is not registered.
self.active_tasks[session_id] = tasks[0]
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_EXCEPTION
)
Expand All @@ -1706,3 +1731,34 @@ async def process_messages():
finally:
for task in pending:
task.cancel()
self.active_tasks.pop(session_id, None)

@app.post(
    "/apps/{app_name}/users/{user_id}/sessions/{session_id}:cancel",
)
async def cancel_session(
    app_name: str, user_id: str, session_id: str
) -> dict[str, Any]:
  """Request cancellation of the agent run registered for a session.

  The task is looked up in `self.active_tasks` by `session_id` alone;
  `app_name` and `user_id` come from the route and are only used in the
  log message. Cancelling the task makes asyncio raise `CancelledError`
  inside it at its next suspension point.

  Args:
    app_name: App name from the request path (logged only).
    user_id: User id from the request path (logged only).
    session_id: Key used to find the running task.

  Returns:
    A dict with `status` set to `"cancelled"` and the `session_id`.

  Raises:
    HTTPException: 404 when no task is registered for the session or the
      registered task has already finished.
  """
  running = self.active_tasks.get(session_id)
  has_live_run = running is not None and not running.done()
  if not has_live_run:
    raise HTTPException(
        status_code=404,
        detail=f"No active run found for session '{session_id}'",
    )

  running.cancel()
  logger.info(
      "Cancelled agent run for session %s (app=%s, user=%s)",
      session_id,
      app_name,
      user_id,
  )
  return {"status": "cancelled", "session_id": session_id}
1 change: 0 additions & 1 deletion src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,7 +1251,6 @@ async def _call_llm_async(
llm_request: LlmRequest,
model_response_event: Event,
) -> AsyncGenerator[LlmResponse, None]:

async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
with tracer.start_as_current_span('call_llm') as span:
# Runs before_model_callback inside the call_llm span so
Expand Down
Loading