Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(event_handler): add route-level custom response validation in OpenAPI utility #6341

Open
wants to merge 26 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
85321d1
feat(api-gateway-resolver): Add option for custom response validation…
Feb 28, 2025
aa7cf6f
feat(docs): Added doc for custom response validation error responses.
Feb 28, 2025
cb7fd6d
feat(unit-test): Add tests for custom response validation error.
Feb 28, 2025
1228632
fix: Formatting.
Feb 28, 2025
b95d521
fix(unit-test): fix failed CI.
Feb 28, 2025
f849930
feat(unit-test): add tests for incorrect types and invalid configs
Feb 28, 2025
bafd19c
refactor: rename response_validation_error_http_status to response_va…
amin-farjadi Mar 7, 2025
9b09bb7
refactor(tests): move unit tests into openapi_validation functional t…
amin-farjadi Mar 7, 2025
bbbd989
feat: add route-specific custom response validation and tests
amin-farjadi Mar 7, 2025
ce7be15
fix: except Route implementation
amin-farjadi Mar 18, 2025
95d9aee
fix: put custom_response_validation_http_code before middleware
amin-farjadi Mar 21, 2025
210b765
feat: route's custom response validation must take precedence over ap…
amin-farjadi Mar 23, 2025
575e713
feat: added more tests.
amin-farjadi Mar 23, 2025
440a3f4
refactor: improved error messagee and tests' descriptions.
amin-farjadi Mar 23, 2025
249554f
feat: updated docs.
amin-farjadi Mar 25, 2025
d0eadf0
move veritifcation method of route custom http code to BaseRouter.
amin-farjadi Mar 25, 2025
2316637
Merge branch 'develop' into feature/route-custom-response-validation
amin-farjadi Mar 25, 2025
59bb4aa
fix: add validate function for route http code to APIGatewayResolver …
amin-farjadi Mar 25, 2025
020c973
feat: add custom_response_validation_http_code to the routes of Bedrock
amin-farjadi Mar 25, 2025
5ea8ffa
fix: make mypy happy
amin-farjadi Mar 25, 2025
ac5dbf4
Merge branch 'develop' into feature/route-custom-response-validation
amin-farjadi Mar 25, 2025
d23be99
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Mar 25, 2025
ec113cb
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Mar 27, 2025
5794c27
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Mar 31, 2025
3e5fb6e
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Apr 2, 2025
6d00446
Merge branch 'develop' into feature/route-custom-response-validation
leandrodamascena Apr 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 71 additions & 4 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
@@ -319,6 +319,7 @@ def __init__(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: HTTPStatus | None = None,
middlewares: list[Callable[..., Response]] | None = None,
):
"""
@@ -360,8 +361,11 @@ def __init__(
Additional OpenAPI extensions as a dictionary.
deprecated: bool
Whether or not to mark this route as deprecated in the OpenAPI schema
custom_response_validation_http_code: HTTPStatus | None, optional
Whether to have custom http status code for this route if response validation fails
middlewares: list[Callable[..., Response]] | None
The list of route middlewares to be called in order.
# TODO
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover?

"""
self.method = method.upper()
self.path = "/" if path.strip() == "" else path
@@ -397,6 +401,8 @@ def __init__(
# _body_field is used to cache the dependant model for the body field
self._body_field: ModelField | None = None

self.custom_response_validation_http_code = custom_response_validation_http_code

def __call__(
self,
router_middlewares: list[Callable],
@@ -505,7 +511,7 @@ def body_field(self) -> ModelField | None:

return self._body_field

def _get_openapi_path(
def _get_openapi_path( # noqa: PLR0912
self,
*,
dependant: Dependant,
@@ -565,6 +571,14 @@ def _get_openapi_path(
},
}

# Add custom response validation response, if exists
if self.custom_response_validation_http_code:
http_code = self.custom_response_validation_http_code.value
operation_responses[http_code] = {
"description": "Response Validation Error",
"content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}},
}

# Add the response to the OpenAPI operation
if self.responses:
for status_code in list(self.responses):
@@ -942,6 +956,7 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
raise NotImplementedError()
@@ -1003,6 +1018,7 @@ def get(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Get route decorator with GET `method`
@@ -1043,6 +1059,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

@@ -1062,6 +1079,7 @@ def post(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Post route decorator with POST `method`
@@ -1103,6 +1121,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

@@ -1122,6 +1141,7 @@ def put(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Put route decorator with PUT `method`
@@ -1163,6 +1183,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

@@ -1182,6 +1203,7 @@ def delete(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Delete route decorator with DELETE `method`
@@ -1222,6 +1244,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

@@ -1241,6 +1264,7 @@ def patch(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Patch route decorator with PATCH `method`
@@ -1284,6 +1308,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

@@ -1303,6 +1328,7 @@ def head(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Head route decorator with HEAD `method`
@@ -1345,6 +1371,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

@@ -2108,6 +2135,29 @@ def swagger_handler():
body=body,
)

def _validate_route_response_validation_error_http_code(
self,
custom_response_validation_http_code: int | HTTPStatus | None,
) -> HTTPStatus | None:
if custom_response_validation_http_code and not self._enable_validation:
msg = (
"'custom_response_validation_http_code' cannot be set for route when enable_validation is False "
"on resolver."
)
raise ValueError(msg)

if (
not isinstance(custom_response_validation_http_code, HTTPStatus)
and custom_response_validation_http_code is not None
):
try:
custom_response_validation_http_code = HTTPStatus(custom_response_validation_http_code)
except ValueError:
msg = f"'{custom_response_validation_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501
raise ValueError(msg) from None

return custom_response_validation_http_code

def route(
self,
rule: str,
@@ -2125,10 +2175,15 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Route decorator includes parameter `method`"""

custom_response_validation_http_code = self._validate_route_response_validation_error_http_code(
custom_response_validation_http_code,
)

def register_resolver(func: AnyCallableT) -> AnyCallableT:
methods = (method,) if isinstance(method, str) else method
logger.debug(f"Adding route using rule {rule} and methods: {','.join(m.upper() for m in methods)}")
@@ -2154,6 +2209,7 @@ def register_resolver(func: AnyCallableT) -> AnyCallableT:
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

@@ -2523,15 +2579,22 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
)

# OpenAPIValidationMiddleware will only raise ResponseValidationError when
# 'self._response_validation_error_http_code' is not None
# 'self._response_validation_error_http_code' is not None or
# when route has custom_response_validation_http_code
if isinstance(exp, ResponseValidationError):
http_code = self._response_validation_error_http_code
# route validation must take precedence over app validation
route_response_validation_http_code = route.custom_response_validation_http_code
http_code = (
route_response_validation_http_code
if route_response_validation_http_code
else self._response_validation_error_http_code
)
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
return self._response_builder_class(
response=Response(
status_code=http_code.value,
content_type=content_types.APPLICATION_JSON,
body={"statusCode": self._response_validation_error_http_code, "detail": errors},
body={"statusCode": http_code, "detail": errors},
),
serializer=self._serializer,
route=route,
@@ -2682,6 +2745,7 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
def register_route(func: AnyCallableT) -> AnyCallableT:
@@ -2708,6 +2772,7 @@ def register_route(func: AnyCallableT) -> AnyCallableT:
frozen_security,
frozen_openapi_extensions,
deprecated,
custom_response_validation_http_code,
)

# Collate Middleware for routes
@@ -2794,6 +2859,7 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
# NOTE: see #1552 for more context.
@@ -2813,6 +2879,7 @@ def route(
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

11 changes: 11 additions & 0 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION

if TYPE_CHECKING:
from http import HTTPStatus
from re import Match

from aws_lambda_powertools.event_handler.openapi.models import Contact, License, SecurityScheme, Server, Tag
@@ -109,6 +110,7 @@ def get( # type: ignore[override]
operation_id: str | None = None,
include_in_schema: bool = True,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
openapi_extensions = None
@@ -129,6 +131,7 @@ def get( # type: ignore[override]
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

@@ -148,6 +151,7 @@ def post( # type: ignore[override]
operation_id: str | None = None,
include_in_schema: bool = True,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
openapi_extensions = None
@@ -168,6 +172,7 @@ def post( # type: ignore[override]
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

@@ -187,6 +192,7 @@ def put( # type: ignore[override]
operation_id: str | None = None,
include_in_schema: bool = True,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
openapi_extensions = None
@@ -207,6 +213,7 @@ def put( # type: ignore[override]
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

@@ -226,6 +233,7 @@ def patch( # type: ignore[override]
operation_id: str | None = None,
include_in_schema: bool = True,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable] | None = None,
):
openapi_extensions = None
@@ -246,6 +254,7 @@ def patch( # type: ignore[override]
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

@@ -265,6 +274,7 @@ def delete( # type: ignore[override]
operation_id: str | None = None,
include_in_schema: bool = True,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
openapi_extensions = None
@@ -285,6 +295,7 @@ def delete( # type: ignore[override]
security,
openapi_extensions,
deprecated,
custom_response_validation_http_code,
middlewares,
)

Original file line number Diff line number Diff line change
@@ -150,6 +150,7 @@ def _handle_response(self, *, route: Route, response: Response):
response.body = self._serialize_response(
field=route.dependant.return_param,
response_content=response.body,
has_route_custom_response_validation=route.custom_response_validation_http_code is not None,
)

return response
@@ -165,6 +166,7 @@ def _serialize_response(
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
has_route_custom_response_validation: bool = False,
) -> Any:
"""
Serialize the response content according to the field type.
@@ -173,8 +175,16 @@ def _serialize_response(
errors: list[dict[str, Any]] = []
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
if errors:
# route-level validation must take precedence over app-level
if has_route_custom_response_validation:
raise ResponseValidationError(
errors=_normalize_errors(errors),
body=response_content,
source="route",
)
if self._has_response_validation_error:
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content)
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app")

raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)

if hasattr(field, "serialize"):
Loading