Skip to content

feat(bedrock_agents): add optional fields to response payload #6336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/__init__.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
Response,
)
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver, BedrockResponse
from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver
from aws_lambda_powertools.event_handler.lambda_function_url import (
LambdaFunctionUrlResolver,
@@ -26,6 +26,7 @@
"ALBResolver",
"ApiGatewayResolver",
"BedrockAgentResolver",
"BedrockResponse",
"CORSConfig",
"LambdaFunctionUrlResolver",
"Response",
60 changes: 48 additions & 12 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
@@ -73,6 +73,7 @@
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response"
_ROUTE_REGEX = "^{}$"
_JSON_DUMP_CALL = partial(json.dumps, separators=(",", ":"), cls=Encoder)
_DEFAULT_CONTENT_TYPE = "application/json"

ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent)
ResponseT = TypeVar("ResponseT")
@@ -255,6 +256,35 @@ def build_allow_methods(methods: set[str]) -> str:
return ",".join(sorted(methods))


class BedrockResponse(Generic[ResponseT]):
"""
Contains the response body, status code, content type, and optional attributes
for session management and knowledge base configuration.
"""

def __init__(
self,
body: Any = None,
status_code: int = 200,
content_type: str = _DEFAULT_CONTENT_TYPE,
session_attributes: dict[str, Any] | None = None,
prompt_session_attributes: dict[str, Any] | None = None,
knowledge_bases_configuration: list[dict[str, Any]] | None = None,
) -> None:
self.body = body
self.status_code = status_code
self.content_type = content_type
self.session_attributes = session_attributes
self.prompt_session_attributes = prompt_session_attributes
self.knowledge_bases_configuration = knowledge_bases_configuration

def is_json(self) -> bool:
"""
Returns True if the response is JSON, based on the Content-Type.
"""
return True


class Response(Generic[ResponseT]):
"""Response data class that provides greater control over what is returned from the proxy event"""

@@ -300,7 +330,7 @@ def is_json(self) -> bool:
content_type = self.headers.get("Content-Type", "")
if isinstance(content_type, list):
content_type = content_type[0]
return content_type.startswith("application/json")
return content_type.startswith(_DEFAULT_CONTENT_TYPE)


class Route:
@@ -572,7 +602,7 @@ def _get_openapi_path(
operation_responses: dict[int, OpenAPIResponse] = {
422: {
"description": "Validation Error",
"content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}HTTPValidationError"}}},
"content": {_DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}HTTPValidationError"}}},
},
}

@@ -581,7 +611,9 @@ def _get_openapi_path(
http_code = self.custom_response_validation_http_code.value
operation_responses[http_code] = {
"description": "Response Validation Error",
"content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}},
"content": {
_DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}},
},
}
# Add model definition
definitions["ResponseValidationError"] = response_validation_error_response_definition
@@ -594,7 +626,7 @@ def _get_openapi_path(
# Case 1: there is not 'content' key
if "content" not in response:
response["content"] = {
"application/json": self._openapi_operation_return(
_DEFAULT_CONTENT_TYPE: self._openapi_operation_return(
param=dependant.return_param,
model_name_map=model_name_map,
field_mapping=field_mapping,
@@ -645,7 +677,7 @@ def _get_openapi_path(
# Add the response schema to the OpenAPI 200 response
operation_responses[200] = {
"description": self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
"content": {"application/json": response_schema},
"content": {_DEFAULT_CONTENT_TYPE: response_schema},
}

operation["responses"] = operation_responses
@@ -1474,7 +1506,10 @@ def __call__(self, app: ApiGatewayResolver) -> dict | tuple | Response:
return self.current_middleware(app, self.next_middleware)


def _registered_api_adapter(app: ApiGatewayResolver, next_middleware: Callable[..., Any]) -> dict | tuple | Response:
def _registered_api_adapter(
app: ApiGatewayResolver,
next_middleware: Callable[..., Any],
) -> dict | tuple | Response | BedrockResponse:
"""
Calls the registered API using the "_route_args" from the Resolver context to ensure the last call
in the chain will match the API route function signature and ensure that Powertools passes the API
@@ -1632,7 +1667,7 @@ def _add_resolver_response_validation_error_response_to_route(
response_validation_error_response = {
"description": "Response Validation Error",
"content": {
"application/json": {
_DEFAULT_CONTENT_TYPE: {
"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"},
},
},
@@ -2151,7 +2186,7 @@ def swagger_handler():
if query_params.get("format") == "json":
return Response(
status_code=200,
content_type="application/json",
content_type=_DEFAULT_CONTENT_TYPE,
body=escaped_spec,
)

@@ -2538,7 +2573,7 @@ def _call_route(self, route: Route, route_arguments: dict[str, str]) -> Response
self._reset_processed_stack()

return self._response_builder_class(
response=self._to_response(
response=self._to_response( # type: ignore[arg-type]
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
),
serializer=self._serializer,
@@ -2627,7 +2662,7 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild

return None

def _to_response(self, result: dict | tuple | Response) -> Response:
def _to_response(self, result: dict | tuple | Response | BedrockResponse) -> Response | BedrockResponse:
"""Convert the route's result to a Response
3 main result types are supported:
@@ -2638,7 +2673,7 @@ def _to_response(self, result: dict | tuple | Response) -> Response:
- Response: returned as is, and allows for more flexibility
"""
status_code = HTTPStatus.OK
if isinstance(result, Response):
if isinstance(result, (Response, BedrockResponse)):
return result
elif isinstance(result, tuple) and len(result) == 2:
# Unpack result dict and status code from tuple
@@ -2971,8 +3006,9 @@ def _get_base_path(self) -> str:
# ALB doesn't have a stage variable, so we just return an empty string
return ""

# BedrockResponse is not used here but adding the same signature to keep strong typing
@override
def _to_response(self, result: dict | tuple | Response) -> Response:
def _to_response(self, result: dict | tuple | Response | BedrockResponse) -> Response | BedrockResponse:
"""Convert the route's result to a Response
ALB requires a non-null body otherwise it converts as HTTP 5xx
19 changes: 15 additions & 4 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
from aws_lambda_powertools.event_handler import ApiGatewayResolver
from aws_lambda_powertools.event_handler.api_gateway import (
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
BedrockResponse,
ProxyEventType,
ResponseBuilder,
)
@@ -32,14 +33,11 @@ class BedrockResponseBuilder(ResponseBuilder):

@override
def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]:
"""Build the full response dict to be returned by the lambda"""
self._route(event, None)

body = self.response.body
if self.response.is_json() and not isinstance(self.response.body, str):
body = self.serializer(self.response.body)

return {
response = {
"messageVersion": "1.0",
"response": {
"actionGroup": event.action_group,
@@ -54,6 +52,19 @@ def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]:
},
}

# Add Bedrock-specific attributes
if isinstance(self.response, BedrockResponse):
if self.response.session_attributes:
response["sessionAttributes"] = self.response.session_attributes

if self.response.prompt_session_attributes:
response["promptSessionAttributes"] = self.response.prompt_session_attributes

if self.response.knowledge_bases_configuration:
response["knowledgeBasesConfiguration"] = self.response.knowledge_bases_configuration

return response


class BedrockAgentResolver(ApiGatewayResolver):
"""Bedrock Agent Resolver
11 changes: 11 additions & 0 deletions docs/core/event_handler/bedrock_agents.md
Original file line number Diff line number Diff line change
@@ -323,6 +323,17 @@ You can enable user confirmation with Bedrock Agents to have your application as

1. Add an openapi extension

### Fine grained responses

???+ info "Note"
The default response only includes the essential fields to keep the payload size minimal, as AWS Lambda has a maximum response size of 25 KB.

You can use `BedrockResponse` class to add additional fields as needed, such as [session attributes, prompt session attributes, and knowledge base configurations](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-response){target="_blank"}.

```python title="working_with_bedrockresponse.py" title="Customzing your Bedrock Response" hl_lines="5 16"
--8<-- "examples/event_handler_bedrock_agents/src/working_with_bedrockresponse.py"
```

## Testing your code

Test your routes by passing an [Agent for Amazon Bedrock proxy event](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-input) request:
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from http import HTTPStatus

from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.event_handler import BedrockAgentResolver
from aws_lambda_powertools.event_handler.api_gateway import BedrockResponse
from aws_lambda_powertools.utilities.typing import LambdaContext

tracer = Tracer()
logger = Logger()
app = BedrockAgentResolver()


@app.get("/return_with_session", description="Returns a hello world with session attributes")
@tracer.capture_method
def hello_world():
return BedrockResponse(
status_code=HTTPStatus.OK.value,
body={"message": "Hello from Bedrock!"},
session_attributes={"user_id": "123"},
prompt_session_attributes={"context": "testing"},
knowledge_bases_configuration=[
{
"knowledgeBaseId": "kb-123",
"retrievalConfiguration": {
"vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"},
},
},
],
)


@logger.inject_lambda_context
@tracer.capture_lambda_handler
def lambda_handler(event: dict, context: LambdaContext):
return app.resolve(event, context)
129 changes: 128 additions & 1 deletion tests/functional/event_handler/_pydantic/test_bedrock_agent.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
import pytest
from typing_extensions import Annotated

from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types
from aws_lambda_powertools.event_handler import BedrockAgentResolver, BedrockResponse, Response, content_types
from aws_lambda_powertools.event_handler.openapi.params import Body
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
from tests.functional.utils import load_event
@@ -202,6 +202,133 @@ def handler() -> Optional[Dict]:
assert schema.get("openapi") == "3.0.3"


def test_bedrock_agent_with_bedrock_response():
# GIVEN a Bedrock Agent event
app = BedrockAgentResolver()

# WHEN using BedrockResponse
@app.get("/claims", description="Gets claims")
def claims():
assert isinstance(app.current_event, BedrockAgentEvent)
assert app.lambda_context == {}
return BedrockResponse(
session_attributes={"user_id": "123"},
prompt_session_attributes={"context": "testing"},
knowledge_bases_configuration=[
{
"knowledgeBaseId": "kb-123",
"retrievalConfiguration": {
"vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"},
},
},
],
)

result = app(load_event("bedrockAgentEvent.json"), {})

assert result["messageVersion"] == "1.0"
assert result["response"]["apiPath"] == "/claims"
assert result["response"]["actionGroup"] == "ClaimManagementActionGroup"
assert result["response"]["httpMethod"] == "GET"
assert result["sessionAttributes"] == {"user_id": "123"}
assert result["promptSessionAttributes"] == {"context": "testing"}
assert result["knowledgeBasesConfiguration"] == [
{
"knowledgeBaseId": "kb-123",
"retrievalConfiguration": {
"vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"},
},
},
]


def test_bedrock_agent_with_empty_bedrock_response():
# GIVEN a Bedrock Agent event
app = BedrockAgentResolver()

@app.get("/claims", description="Gets claims")
def claims():
return BedrockResponse(body={"message": "test"})

# WHEN calling the event handler
result = app(load_event("bedrockAgentEvent.json"), {})

# THEN process event correctly without optional attributes
assert result["messageVersion"] == "1.0"
assert result["response"]["httpStatusCode"] == 200
assert "sessionAttributes" not in result
assert "promptSessionAttributes" not in result
assert "knowledgeBasesConfiguration" not in result


def test_bedrock_agent_with_partial_bedrock_response():
# GIVEN a Bedrock Agent event
app = BedrockAgentResolver()

@app.get("/claims", description="Gets claims")
def claims() -> Dict[str, Any]:
return BedrockResponse(
body={"message": "test"},
session_attributes={"user_id": "123"},
# Only include session_attributes to test partial response
)

# WHEN calling the event handler
result = app(load_event("bedrockAgentEvent.json"), {})

# THEN process event correctly with only session_attributes
assert result["messageVersion"] == "1.0"
assert result["response"]["httpStatusCode"] == 200
assert result["sessionAttributes"] == {"user_id": "123"}
assert "promptSessionAttributes" not in result
assert "knowledgeBasesConfiguration" not in result


def test_bedrock_agent_with_string():
# GIVEN a Bedrock Agent event
app = BedrockAgentResolver()

@app.get("/claims", description="Gets claims")
def claims() -> str:
return "a"

# WHEN calling the event handler
result = app(load_event("bedrockAgentEvent.json"), {})

# THEN process event correctly with only session_attributes
assert result["messageVersion"] == "1.0"
assert result["response"]["httpStatusCode"] == 200


def test_bedrock_agent_with_different_attributes_combination():
# GIVEN a Bedrock Agent event
app = BedrockAgentResolver()

@app.get("/claims", description="Gets claims")
def claims() -> Dict[str, Any]:
return BedrockResponse(
body={"message": "test"},
prompt_session_attributes={"context": "testing"},
knowledge_bases_configuration=[
{
"knowledgeBaseId": "kb-123",
"retrievalConfiguration": {"vectorSearchConfiguration": {"numberOfResults": 3}},
},
],
# Omit session_attributes to test different combination
)

# WHEN calling the event handler
result = app(load_event("bedrockAgentEvent.json"), {})

# THEN process event correctly with specific attributes
assert result["messageVersion"] == "1.0"
assert result["response"]["httpStatusCode"] == 200
assert "sessionAttributes" not in result
assert result["promptSessionAttributes"] == {"context": "testing"}
assert result["knowledgeBasesConfiguration"][0]["knowledgeBaseId"] == "kb-123"


def test_bedrock_resolver_with_openapi_extensions():
# GIVEN BedrockAgentResolver is initialized with enable_validation=True
app = BedrockAgentResolver(enable_validation=True)