Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 224 additions & 10 deletions samcli/commands/local/lib/swagger/parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Handles Swagger Parsing"""

import logging
from typing import Dict, List, Union
from typing import Any, Dict, List, Optional, Set, Union

from samcli.commands.local.lib.swagger.integration_uri import IntegrationType, LambdaUri
from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator
Expand All @@ -17,6 +17,9 @@

LOG = logging.getLogger(__name__)

# Maximum depth for resolving nested $ref to prevent infinite recursion
_MAX_REF_RESOLUTION_DEPTH = 10


class SwaggerParser:
_AUTHORIZER_KEY = "x-amazon-apigateway-authorizer"
Expand Down Expand Up @@ -64,6 +67,145 @@ def get_binary_media_types(self):
"""
return self.swagger.get(self._BINARY_MEDIA_TYPES_EXTENSION_KEY) or []

def _resolve_ref(
self,
ref_value: Any,
visited: Optional[Set[str]] = None,
depth: int = 0,
) -> Optional[Dict]:
"""
Resolve a JSON Reference ($ref) within the Swagger document.

This method handles:
- Local JSON Pointer references (#/path/to/object)
- Nested $ref (a $ref pointing to another $ref)
- Circular $ref detection
- Depth limiting to prevent infinite recursion

Parameters
----------
ref_value : str
The $ref value, e.g., "#/components/x-amazon-apigateway-integrations/lambda"
visited : Set[str], optional
Set of already visited $ref values for circular reference detection
depth : int, optional
Current recursion depth

Returns
-------
dict or None
The resolved object, or None if not found or if the reference is invalid/unsupported
"""
if visited is None:
visited = set()

# Check for invalid or empty ref
if not ref_value or not isinstance(ref_value, str):
LOG.debug("Invalid $ref value: %s", ref_value)
return None

# Check for external references (not supported in local mode)
if not ref_value.startswith("#/"):
if ref_value.startswith("http://") or ref_value.startswith("https://"):
LOG.warning(
"External URL $ref '%s' is not supported in SAM CLI local mode. "
"Consider inlining the referenced content.",
ref_value,
)
elif ref_value.startswith("./") or ref_value.startswith("../") or "/" in ref_value:
LOG.warning(
"External file $ref '%s' is not supported in SAM CLI local mode. "
"Consider inlining the referenced content or using a single OpenAPI file.",
ref_value,
)
else:
LOG.debug("Unsupported $ref format: %s", ref_value)
return None

# Check for circular reference
if ref_value in visited:
LOG.warning("Circular $ref detected: %s. Skipping to prevent infinite loop.", ref_value)
return None

# Check depth limit
if depth >= _MAX_REF_RESOLUTION_DEPTH:
LOG.warning(
"Maximum $ref resolution depth (%d) exceeded for '%s'. "
"This may indicate deeply nested or circular references.",
_MAX_REF_RESOLUTION_DEPTH,
ref_value,
)
return None

# Mark this ref as visited
visited.add(ref_value)

# Remove the leading "#/" and split by "/"
# Handle URL-encoded characters in path (e.g., ~0 for ~, ~1 for /)
path_parts = ref_value[2:].split("/")
decoded_parts = []
for part in path_parts:
# JSON Pointer escaping: ~1 = /, ~0 = ~
decoded_part = part.replace("~1", "/").replace("~0", "~")
decoded_parts.append(decoded_part)

current = self.swagger
for part in decoded_parts:
if isinstance(current, dict) and part in current:
current = current[part]
else:
LOG.debug("Unable to resolve $ref '%s': path component '%s' not found", ref_value, part)
return None

# If the resolved value itself contains a $ref, resolve it recursively
if isinstance(current, dict) and "$ref" in current:
nested_ref = current.get("$ref")
LOG.debug("Found nested $ref '%s' while resolving '%s'", nested_ref, ref_value)
return self._resolve_ref(nested_ref, visited, depth + 1)

return current if isinstance(current, dict) else None

def _resolve_object_refs(self, obj: Any, visited: Optional[Set[str]] = None, depth: int = 0) -> Any:
"""
Recursively resolve all $ref in an object.

This is useful for objects that may have $ref at various levels,
not just at the top level.

Parameters
----------
obj : Any
The object to resolve $refs in
visited : Set[str], optional
Set of already visited $ref values
depth : int, optional
Current recursion depth

Returns
-------
Any
The object with all $refs resolved
"""
if visited is None:
visited = set()

if depth >= _MAX_REF_RESOLUTION_DEPTH:
return obj

if isinstance(obj, dict):
if "$ref" in obj:
ref_value = obj.get("$ref")
resolved = self._resolve_ref(ref_value, visited.copy(), depth)
if resolved is not None:
return self._resolve_object_refs(resolved, visited, depth + 1)
return obj
else:
return {k: self._resolve_object_refs(v, visited, depth + 1) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._resolve_object_refs(item, visited, depth + 1) for item in obj]
else:
return obj

def get_authorizers(self, event_type: str = Route.API) -> Dict[str, Authorizer]:
"""
Parse Swagger document and returns a list of Authorizer objects
Expand Down Expand Up @@ -97,14 +239,43 @@ def get_authorizers(self, event_type: str = Route.API) -> Dict[str, Authorizer]:
)

for auth_name, properties in authorizer_dict.items():
authorizer_object = properties.get(self._AUTHORIZER_KEY)
# Resolve $ref if the security scheme itself is a reference
resolved_properties = properties
if isinstance(properties, dict) and "$ref" in properties:
ref_value = properties.get("$ref")
LOG.debug("Resolving $ref in security scheme '%s': %s", auth_name, ref_value)
resolved_properties = self._resolve_ref(ref_value)
if not resolved_properties:
LOG.warning(
"Unable to resolve $ref '%s' for security scheme '%s', skipping",
ref_value,
auth_name,
)
continue

authorizer_object = (
resolved_properties.get(self._AUTHORIZER_KEY) if isinstance(resolved_properties, dict) else None
)

if not authorizer_object:
LOG.warning("Skip parsing unsupported authorizer '%s'", auth_name)
continue

# Resolve $ref if the authorizer object itself is a reference
if isinstance(authorizer_object, dict) and "$ref" in authorizer_object:
ref_value = authorizer_object.get("$ref")
LOG.debug("Resolving $ref in x-amazon-apigateway-authorizer for '%s': %s", auth_name, ref_value)
authorizer_object = self._resolve_ref(ref_value)
if not authorizer_object:
LOG.warning(
"Unable to resolve $ref '%s' for authorizer '%s', skipping",
ref_value,
auth_name,
)
continue

authorizer_type = authorizer_object.get(SwaggerParser._AUTHORIZER_TYPE, "").lower()
payload_version = authorizer_object.get(SwaggerParser._AUTHORIZER_PAYLOAD_VERSION)
payload_version: str = authorizer_object.get(SwaggerParser._AUTHORIZER_PAYLOAD_VERSION) or ""

if event_type == Route.HTTP and payload_version not in LambdaAuthorizer.PAYLOAD_VERSIONS:
raise InvalidSecurityDefinition(f"Authorizer '{auth_name}' contains an invalid payload version")
Expand All @@ -124,7 +295,7 @@ def get_authorizers(self, event_type: str = Route.API) -> Dict[str, Authorizer]:
continue

identity_sources = self._get_lambda_identity_sources(
auth_name, authorizer_type, event_type, properties, authorizer_object
auth_name, authorizer_type, event_type, resolved_properties, authorizer_object
)

validation_expression = authorizer_object.get(SwaggerParser._AUTHORIZER_LAMBDA_VALIDATION)
Expand Down Expand Up @@ -339,8 +510,36 @@ def get_routes(self, event_type=Route.API) -> List[Route]:
paths_dict = self.swagger.get("paths", {})

for full_path, path_config in paths_dict.items():
for method, method_config in path_config.items():
function_name = self._get_integration_function_name(method_config)
# Resolve $ref if path itself is a reference
resolved_path_config = path_config
if isinstance(path_config, dict) and "$ref" in path_config:
ref_value = path_config.get("$ref")
LOG.debug("Resolving $ref in path '%s': %s", full_path, ref_value)
resolved_path_config = self._resolve_ref(ref_value)
if not resolved_path_config:
LOG.debug("Unable to resolve $ref '%s' for path '%s', skipping", ref_value, full_path)
continue

if not isinstance(resolved_path_config, dict):
continue

for method, method_config in resolved_path_config.items():
# Resolve $ref if method config is a reference
resolved_method_config = method_config
if isinstance(method_config, dict) and "$ref" in method_config:
ref_value = method_config.get("$ref")
LOG.debug("Resolving $ref in method '%s %s': %s", method, full_path, ref_value)
resolved_method_config = self._resolve_ref(ref_value)
if not resolved_method_config:
LOG.debug(
"Unable to resolve $ref '%s' for method '%s %s', skipping",
ref_value,
method,
full_path,
)
continue

function_name = self._get_integration_function_name(resolved_method_config)
if not function_name:
LOG.debug(
"Lambda function integration not found in Swagger document at path='%s' method='%s'",
Expand All @@ -353,9 +552,9 @@ def get_routes(self, event_type=Route.API) -> List[Route]:
if normalized_method.lower() == self._ANY_METHOD_EXTENSION_KEY:
# Convert to a more commonly used method notation
normalized_method = self._ANY_METHOD
payload_format_version = self._get_payload_format_version(method_config)
payload_format_version = self._get_payload_format_version(resolved_method_config)

authorizers = method_config.get(SwaggerParser._SWAGGER_SECURITY, None)
authorizers = resolved_method_config.get(SwaggerParser._SWAGGER_SECURITY, None)

authorizer_name = None
use_default_authorizer = True
Expand Down Expand Up @@ -395,7 +594,7 @@ def get_routes(self, event_type=Route.API) -> List[Route]:
methods=[normalized_method],
event_type=event_type,
payload_format_version=payload_format_version,
operation_name=method_config.get("operationId"),
operation_name=resolved_method_config.get("operationId"),
stack_path=self.stack_path,
authorizer_name=authorizer_name,
use_default_authorizer=use_default_authorizer,
Expand All @@ -410,6 +609,8 @@ def _get_integration(self, method_config):
Integration configuration is defined under the special "x-amazon-apigateway-integration" key. We care only
about Lambda integrations, which are of type aws_proxy, and ignore the rest.

This method also resolves $ref references in the integration configuration.

Parameters
----------
method_config : dict
Expand All @@ -425,10 +626,23 @@ def _get_integration(self, method_config):

integration = method_config[self._INTEGRATION_KEY]

# Handle $ref in integration - resolve the reference to get the actual integration config
if isinstance(integration, dict) and "$ref" in integration:
ref_value = integration.get("$ref")
LOG.debug("Resolving $ref in x-amazon-apigateway-integration: %s", ref_value)
integration = self._resolve_ref(ref_value)
if not integration:
LOG.debug("Unable to resolve $ref '%s' in x-amazon-apigateway-integration", ref_value)
return None

# Get the integration type, checking that integration is a valid dict first
integration_type = integration.get("type") if isinstance(integration, dict) else None

if (
integration
and isinstance(integration, dict)
and integration.get("type").lower() == IntegrationType.aws_proxy.value
and integration_type
and integration_type.lower() == IntegrationType.aws_proxy.value
):
# Integration must be "aws_proxy" otherwise we don't care about it
return integration
Expand Down
Loading
Loading