Skip to content

Commit 4ca4e2d

Browse files
committed
fix bedrock response
1 parent 41468d6 commit 4ca4e2d

File tree

2 files changed

+54
-64
lines changed

2 files changed

+54
-64
lines changed

aws_lambda_powertools/event_handler/bedrock_agent.py

Lines changed: 31 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
class BedrockResponse:
2525
def __init__(
2626
self,
27-
body: Any,
27+
body: Any = None,
2828
status_code: int = 200,
2929
content_type: str = "application/json",
3030
session_attributes: dict[str, Any] | None = None,
@@ -36,61 +36,38 @@ def __init__(
3636
self.content_type = content_type
3737
self.session_attributes = session_attributes
3838
self.prompt_session_attributes = prompt_session_attributes
39-
40-
if knowledge_bases_configuration is not None:
41-
if not isinstance(knowledge_bases_configuration, list) or not all(
42-
isinstance(item, dict) for item in knowledge_bases_configuration
43-
):
44-
raise ValueError("knowledge_bases_configuration must be a list of dictionaries")
45-
4639
self.knowledge_bases_configuration = knowledge_bases_configuration
4740

48-
def to_dict(self, event: BedrockAgentEvent) -> dict[str, Any]:
49-
result: dict[str, Any] = {
50-
"messageVersion": "1.0",
51-
"response": {
52-
"apiPath": event.api_path,
53-
"actionGroup": event.action_group,
54-
"httpMethod": event.http_method,
55-
"httpStatusCode": self.status_code,
56-
"responseBody": {
57-
self.content_type: {"body": json.dumps(self.body) if isinstance(self.body, dict) else self.body},
58-
},
59-
},
41+
def is_json(self) -> bool:
42+
return self.content_type == "application/json"
43+
44+
def to_dict(self) -> dict:
45+
return {
46+
"body": self.body,
47+
"status_code": self.status_code,
48+
"content_type": self.content_type,
49+
"session_attributes": self.session_attributes,
50+
"prompt_session_attributes": self.prompt_session_attributes,
51+
"knowledge_bases_configuration": self.knowledge_bases_configuration,
6052
}
6153

62-
# Add optional attributes if they exist
63-
if self.session_attributes is not None:
64-
result["sessionAttributes"] = self.session_attributes
65-
66-
if self.prompt_session_attributes is not None:
67-
result["promptSessionAttributes"] = self.prompt_session_attributes
68-
69-
if self.knowledge_bases_configuration is not None:
70-
result["knowledgeBasesConfiguration"] = self.knowledge_bases_configuration
71-
72-
return result
73-
7454

7555
class BedrockResponseBuilder(ResponseBuilder):
76-
"""
77-
Bedrock Response Builder. This builds the response dict to be returned by Lambda when using Bedrock Agents.
78-
79-
Since the payload format is different from the standard API Gateway Proxy event, we override the build method.
80-
"""
81-
8256
@override
8357
def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]:
84-
"""
85-
Build the response dictionary to be returned by the Lambda function.
86-
"""
8758
self._route(event, None)
8859

89-
body = self.response.body
90-
if self.response.is_json() and not isinstance(self.response.body, str):
91-
body = self.serializer(self.response.body)
60+
bedrock_response = None
61+
if isinstance(self.response.body, dict) and "body" in self.response.body:
62+
bedrock_response = BedrockResponse(**self.response.body)
63+
body = bedrock_response.body
64+
else:
65+
body = self.response.body
66+
67+
if self.response.is_json() and not isinstance(body, str):
68+
body = self.serializer(body)
9269

93-
base_response = {
70+
response = {
9471
"messageVersion": "1.0",
9572
"response": {
9673
"actionGroup": event.action_group,
@@ -105,22 +82,16 @@ def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]:
10582
},
10683
}
10784

108-
if isinstance(self.response, BedrockResponse):
109-
self._add_bedrock_specific_configs(base_response)
110-
111-
return base_response
112-
113-
def _add_bedrock_specific_configs(self, response: dict[str, Any]) -> None:
114-
if not isinstance(self.response, BedrockResponse):
115-
return
116-
117-
optional_configs = {
118-
"sessionAttributes": self.response.session_attributes,
119-
"promptSessionAttributes": self.response.prompt_session_attributes,
120-
"knowledgeBasesConfiguration": self.response.knowledge_bases_configuration,
121-
}
85+
# Add Bedrock-specific attributes
86+
if bedrock_response:
87+
if bedrock_response.session_attributes:
88+
response["sessionAttributes"] = bedrock_response.session_attributes
89+
if bedrock_response.prompt_session_attributes:
90+
response["promptSessionAttributes"] = bedrock_response.prompt_session_attributes
91+
if bedrock_response.knowledge_bases_configuration:
92+
response["knowledgeBasesConfiguration"] = bedrock_response.knowledge_bases_configuration
12293

123-
response.update({k: v for k, v in optional_configs.items() if v is not None})
94+
return response
12495

12596

12697
class BedrockAgentResolver(ApiGatewayResolver):

tests/functional/event_handler/_pydantic/test_bedrock_agent.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,16 +208,35 @@ def test_bedrock_agent_with_bedrock_response():
208208

209209
# WHEN using BedrockResponse
210210
@app.get("/claims", description="Gets claims")
211-
def claims() -> Dict[str, Any]:
211+
def claims():
212212
assert isinstance(app.current_event, BedrockAgentEvent)
213213
assert app.lambda_context == {}
214-
return BedrockResponse(body={"message": "success"}, session_attributes={"last_request": "get_claims"})
214+
return BedrockResponse(
215+
session_attributes={"user_id": "123"},
216+
prompt_session_attributes={"context": "testing"},
217+
knowledge_bases_configuration=[
218+
{
219+
"knowledgeBaseId": "kb-123",
220+
"retrievalConfiguration": {
221+
"vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"},
222+
},
223+
},
224+
],
225+
)
215226

216227
result = app(load_event("bedrockAgentEvent.json"), {})
217-
print(result)
218228

219-
# To be implemented: check if session_attributes
220229
assert result["messageVersion"] == "1.0"
221230
assert result["response"]["apiPath"] == "/claims"
222231
assert result["response"]["actionGroup"] == "ClaimManagementActionGroup"
223232
assert result["response"]["httpMethod"] == "GET"
233+
assert result["sessionAttributes"] == {"user_id": "123"}
234+
assert result["promptSessionAttributes"] == {"context": "testing"}
235+
assert result["knowledgeBasesConfiguration"] == [
236+
{
237+
"knowledgeBaseId": "kb-123",
238+
"retrievalConfiguration": {
239+
"vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"},
240+
},
241+
},
242+
]

0 commit comments

Comments
 (0)