Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 65 additions & 6 deletions python/packages/kagent-core/src/kagent/core/a2a/_task_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging

import httpx
from a2a.server.tasks import TaskStore
Expand All @@ -8,6 +9,8 @@

from kagent.core.a2a import read_metadata_value

logger = logging.getLogger(__name__)


class KAgentTaskResponse(BaseModel):
"""Wrapper for KAgent controller API responses.
Expand All @@ -25,8 +28,17 @@ class KAgentTaskResponse(BaseModel):
class KAgentTaskStore(TaskStore):
"""
A task store that persists A2A tasks to KAgent via REST API.

Transient transport errors (idle keep-alive connections reset by a service
mesh, controller pod restarts, etc.) are handled transparently: each HTTP
operation is retried once on the same shared client, which is never closed
by the retry logic. HTTP error responses (4xx/5xx) are not retried; they are
surfaced immediately so real failures are never swallowed.
"""

# Maximum number of automatic retries for transient transport errors.
_MAX_RETRIES = 1

def __init__(self, client: httpx.AsyncClient):
"""Initialize the task store.

Expand All @@ -46,6 +58,46 @@ def _clean_partial_events(self, history: list[Message]) -> list[Message]:
"""Remove partial streaming events from history."""
return [item for item in history if not self._is_partial_event(item)]

async def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
    """Execute an HTTP request, retrying on transient transport errors.

    The request is attempted up to ``_MAX_RETRIES + 1`` times. Only
    ``httpx.TransportError`` triggers a retry; any response that is received
    is returned as-is, whatever its status code, so the caller decides how to
    treat 4xx/5xx (typically via ``response.raise_for_status()``).

    Args:
        method: HTTP method string ("GET", "POST", "DELETE", ...)
        url: Request URL (relative to the client's base_url)
        **kwargs: Extra keyword arguments forwarded to httpx.AsyncClient.request

    Returns:
        The httpx.Response of the first attempt that did not hit a transport error.

    Raises:
        httpx.TransportError: If the transport error persists after all retries.
            The exception from the final attempt is re-raised unchanged.
    """
    # Guard against a misconfigured negative value so at least one attempt is made.
    retries = max(self._MAX_RETRIES, 0)
    total_attempts = retries + 1

    # Retryable attempts: a transport failure here is logged and swallowed so
    # the next attempt can run.
    for attempt in range(1, total_attempts):
        try:
            return await self.client.request(method, url, **kwargs)
        except httpx.TransportError as exc:
            # Don't close the shared AsyncClient here: it is reused across the
            # process. Just try again on the same client.
            logger.warning(
                "TransportError on %s %s (attempt %d/%d): %s — will retry with a new connection",
                method,
                url,
                attempt,
                total_attempts,
                exc,
            )

    # Final attempt: nothing left to retry, so a transport failure is logged
    # accurately and propagated to the caller instead of being dropped.
    try:
        return await self.client.request(method, url, **kwargs)
    except httpx.TransportError as exc:
        logger.warning(
            "TransportError on %s %s (attempt %d/%d): %s — retries exhausted, giving up",
            method,
            url,
            total_attempts,
            total_attempts,
            exc,
        )
        raise

@override
async def save(self, task: Task, context=None) -> None:
"""Save a task to KAgent.
Expand All @@ -59,13 +111,18 @@ async def save(self, task: Task, context=None) -> None:
context: Server call context (unused, for a2a-sdk 0.3+ compatibility)

Raises:
httpx.HTTPStatusError: If the API request fails
httpx.HTTPStatusError: If the API request fails with a non-2xx status.
httpx.TransportError: If a transport error persists after retries.
"""
# Clean any partial events from history before saving
history = task.history or []
task.history = self._clean_partial_events(history)

response = await self.client.post("/api/tasks", json=task.model_dump(mode="json"))
response = await self._request_with_retry(
"POST",
"/api/tasks",
json=task.model_dump(mode="json"),
)
response.raise_for_status()

# Signal that save completed (event-based sync)
Expand All @@ -84,9 +141,10 @@ async def get(self, task_id: str, context=None) -> Task | None:
The task if found, None otherwise

Raises:
httpx.HTTPStatusError: If the API request fails (except 404)
httpx.HTTPStatusError: If the API request fails (except 404).
httpx.TransportError: If a transport error persists after retries.
"""
response = await self.client.get(f"/api/tasks/{task_id}")
response = await self._request_with_retry("GET", f"/api/tasks/{task_id}")
if response.status_code == 404:
return None
response.raise_for_status()
Expand All @@ -104,9 +162,10 @@ async def delete(self, task_id: str, context=None) -> None:
context: Server call context (unused, for a2a-sdk 0.3+ compatibility)

Raises:
httpx.HTTPStatusError: If the API request fails
httpx.HTTPStatusError: If the API request fails.
httpx.TransportError: If a transport error persists after retries.
"""
response = await self.client.delete(f"/api/tasks/{task_id}")
response = await self._request_with_retry("DELETE", f"/api/tasks/{task_id}")
response.raise_for_status()

async def wait_for_save(self, task_id: str, timeout: float = 5.0) -> None:
Expand Down
Loading