Skip to content

[develop] Upgrade connexion to 2.15.0rc3, upgrade Werkzeug to >=3.0.3 to address CVE-2024-34069 #6932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ CHANGELOG
- Upgrade DCGM to version 4.2.3 (from 3.3.6) for all OSs except AL2.
- Upgrade Python to 3.12.11 (from 3.12.8) for all OSs except AL2.
- Upgrade Intel MPI Library to 2021.16.0 (from 2021.13.1).
- Upgrade Connexion to ~=2.15.0rc3 (from ~=2.13.0).
- Upgrade Werkzeug to ~=3.1 (from ~=2.0) to address [CVE-2024-34069](https://nvd.nist.gov/vuln/detail/cve-2024-34069).

**BUG FIXES**
- Fix an issue where Security Group validation failed when a rule contained both IPv4 ranges (IpRanges) and security group references (UserIdGroupPairs).
Expand Down
6 changes: 3 additions & 3 deletions cli/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ aws-cdk.core~=1.164
aws_cdk.aws-cloudwatch~=1.164
aws_cdk.aws-lambda~=1.164
boto3>=1.16.14
connexion~=2.13.0
flask>=2.2.5,<2.3
jinja2~=3.0
jmespath~=0.10
jsii==1.85.0
marshmallow~=3.10
PyYAML>=5.3.1,!=5.4
tabulate>=0.8.8,<=0.8.10
werkzeug~=2.0
connexion~=2.15.0rc3
werkzeug~=3.1
flask~=3.0
packaging~=25.0
7 changes: 4 additions & 3 deletions cli/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ def readme():
"aws-cdk.aws-ssm~=" + CDK_VERSION,
"aws-cdk.aws-sqs~=" + CDK_VERSION,
"aws-cdk.aws-cloudformation~=" + CDK_VERSION,
"werkzeug~=2.0",
"connexion~=2.13.0",
"flask>=2.2.5,<2.3",
"connexion~=2.15.0rc3",
"jmespath~=0.10",
"jsii==1.85.0",
"werkzeug~=3.1",
"flask~=3.0",
"packaging~=25.0",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to add packaging?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

packaging was already listed in requirements.txt; I didn't introduce it. It was missing from setup.py, so I added it here to keep the two in sync.

]

LAMBDA_REQUIRES = [
Expand Down
71 changes: 46 additions & 25 deletions cli/src/pcluster/api/awslambda/serverless_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import json
import os
import sys
from urllib.parse import unquote, unquote_plus, urlencode

from werkzeug.datastructures import Headers, MultiDict, iter_multi_items
from werkzeug.datastructures import Headers, iter_multi_items
from werkzeug.http import HTTP_STATUS_CODES
from werkzeug.urls import url_encode, url_unquote, url_unquote_plus
from werkzeug.wrappers import Response

# List of MIME types that should not be base64 encoded. MIME types within `text/*`
Expand Down Expand Up @@ -95,8 +95,8 @@ def encode_query_string(event):
if not params:
params = ""
if is_alb_event(event):
params = MultiDict((url_unquote_plus(k), url_unquote_plus(v)) for k, v in iter_multi_items(params))
return url_encode(params)
params = [(unquote_plus(k), unquote_plus(v)) for k, v in iter_multi_items(params)]
return urlencode(params, doseq=True)


def get_script_name(headers, request_context):
Expand All @@ -108,7 +108,7 @@ def get_script_name(headers, request_context):
"1",
]

if headers.get("Host", "").endswith(".amazonaws.com") and not strip_stage_path:
if "amazonaws.com" in headers.get("Host", "") and not strip_stage_path:

Check failure

Code scanning / CodeQL

Incomplete URL substring sanitization High

The string
amazonaws.com
may be at an arbitrary position in the sanitized URL.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we changing this check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

serverless_wsgi.py is a file we simply copy and paste from https://github.com/logandk/serverless-wsgi/blob/master/serverless_wsgi.py.

script_name = "/{}".format(request_context.get("stage", ""))
else:
script_name = ""
Expand Down Expand Up @@ -138,7 +138,7 @@ def setup_environ_items(environ, headers):
def generate_response(response, event):
returndict = {"statusCode": response.status_code}

if "multiValueHeaders" in event:
if "multiValueHeaders" in event and event["multiValueHeaders"]:
returndict["multiValueHeaders"] = group_headers(response.headers)
else:
returndict["headers"] = split_headers(response.headers)
Expand All @@ -164,12 +164,27 @@ def generate_response(response, event):
return returndict


def strip_express_gateway_query_params(path):
    """Return ``path`` with any query string removed.

    Contrary to regular AWS Lambda HTTP events, Express Gateway
    (https://github.com/ExpressGateway/express-gateway-plugin-lambda)
    adds query parameters to the path, which we need to strip.

    :param path: request path, possibly followed by ``?`` and a query string
    :return: the part of ``path`` before the first ``?``, or ``path`` unchanged when it contains no ``?``
    """
    # partition() stops at the first "?" and yields the whole string as the head when there is none,
    # so no separate membership check is needed.
    return path.partition("?")[0]


def handle_request(app, event, context):
if event.get("source") in ["aws.events", "serverless-plugin-warmup"]:
print("Lambda warming event received, skipping handler")
return {}

if event.get("version") is None and event.get("isBase64Encoded") is None and not is_alb_event(event):
if (
event.get("version") is None
and event.get("isBase64Encoded") is None
and event.get("requestPath") is not None
and not is_alb_event(event)
):
return handle_lambda_integration(app, event, context)

if event.get("version") == "2.0":
Expand All @@ -179,7 +194,7 @@ def handle_request(app, event, context):


def handle_payload_v1(app, event, context):
if "multiValueHeaders" in event:
if "multiValueHeaders" in event and event["multiValueHeaders"]:
headers = Headers(event["multiValueHeaders"])
else:
headers = Headers(event["headers"])
Expand All @@ -189,35 +204,35 @@ def handle_payload_v1(app, event, context):
# If a user is using a custom domain on API Gateway, they may have a base
# path in their URL. This allows us to strip it out via an optional
# environment variable.
path_info = event["path"]
path_info = strip_express_gateway_query_params(event["path"])
base_path = os.environ.get("API_GATEWAY_BASE_PATH")
if base_path:
script_name = "/" + base_path

if path_info.startswith(script_name):
path_info = path_info[len(script_name) :] # noqa: E203

body = event["body"] or ""
body = event.get("body") or ""
body = get_body_bytes(event, body)

environ = {
"CONTENT_LENGTH": str(len(body)),
"CONTENT_TYPE": headers.get("Content-Type", ""),
"PATH_INFO": url_unquote(path_info),
"PATH_INFO": unquote(path_info),
"QUERY_STRING": encode_query_string(event),
"REMOTE_ADDR": event.get("requestContext", {}).get("identity", {}).get("sourceIp", ""),
"REMOTE_USER": event.get("requestContext", {}).get("authorizer", {}).get("principalId", ""),
"REMOTE_USER": (event.get("requestContext", {}).get("authorizer") or {}).get("principalId", ""),
"REQUEST_METHOD": event.get("httpMethod", {}),
"SCRIPT_NAME": script_name,
"SERVER_NAME": headers.get("Host", "lambda"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here and below, why do we need to change the port and protocol?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

serverless_wsgi.py is a file we simply copy and paste from https://github.com/logandk/serverless-wsgi/blob/master/serverless_wsgi.py. I don't know why upstream made this change.

"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stderr,
"wsgi.input": io.BytesIO(body),
"wsgi.multiprocess": False,
"wsgi.multithread": False,
"wsgi.run_once": False,
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
"wsgi.version": (1, 0),
"serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
"serverless.event": event,
Expand All @@ -237,31 +252,37 @@ def handle_payload_v2(app, event, context):

script_name = get_script_name(headers, event.get("requestContext", {}))

path_info = event["rawPath"]
path_info = strip_express_gateway_query_params(event["rawPath"])
base_path = os.environ.get("API_GATEWAY_BASE_PATH")
if base_path:
script_name = "/" + base_path

if path_info.startswith(script_name):
path_info = path_info[len(script_name) :] # noqa: E203

body = event.get("body", "")
body = get_body_bytes(event, body)

headers["Cookie"] = "; ".join(event.get("cookies", []))

environ = {
"CONTENT_LENGTH": str(len(body)),
"CONTENT_LENGTH": str(len(body or "")),
"CONTENT_TYPE": headers.get("Content-Type", ""),
"PATH_INFO": url_unquote(path_info),
"PATH_INFO": unquote(path_info),
"QUERY_STRING": event.get("rawQueryString", ""),
"REMOTE_ADDR": event.get("requestContext", {}).get("http", {}).get("sourceIp", ""),
"REMOTE_USER": event.get("requestContext", {}).get("authorizer", {}).get("principalId", ""),
"REQUEST_METHOD": event.get("requestContext", {}).get("http", {}).get("method", ""),
"SCRIPT_NAME": script_name,
"SERVER_NAME": headers.get("Host", "lambda"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stderr,
"wsgi.input": io.BytesIO(body),
"wsgi.multiprocess": False,
"wsgi.multithread": False,
"wsgi.run_once": False,
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
"wsgi.version": (1, 0),
"serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
"serverless.event": event,
Expand All @@ -282,7 +303,7 @@ def handle_lambda_integration(app, event, context):

script_name = get_script_name(headers, event)

path_info = event["requestPath"]
path_info = strip_express_gateway_query_params(event["requestPath"])

for key, value in event.get("path", {}).items():
path_info = path_info.replace("{%s}" % key, value)
Expand All @@ -293,23 +314,23 @@ def handle_lambda_integration(app, event, context):
body = get_body_bytes(event, body)

environ = {
"CONTENT_LENGTH": str(len(body)),
"CONTENT_LENGTH": str(len(body or "")),
"CONTENT_TYPE": headers.get("Content-Type", ""),
"PATH_INFO": url_unquote(path_info),
"QUERY_STRING": url_encode(event.get("query", {})),
"PATH_INFO": unquote(path_info),
"QUERY_STRING": urlencode(event.get("query", {}), doseq=True),
"REMOTE_ADDR": event.get("identity", {}).get("sourceIp", ""),
"REMOTE_USER": event.get("principalId", ""),
"REQUEST_METHOD": event.get("method", ""),
"SCRIPT_NAME": script_name,
"SERVER_NAME": headers.get("Host", "lambda"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stderr,
"wsgi.input": io.BytesIO(body),
"wsgi.multiprocess": False,
"wsgi.multithread": False,
"wsgi.run_once": False,
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
"wsgi.version": (1, 0),
"serverless.authorizer": event.get("enhancedAuthContext"),
"serverless.event": event,
Expand Down
28 changes: 25 additions & 3 deletions cli/src/pcluster/api/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
# Generated by OpenAPI Generator (python-flask)

import datetime
import json

import six
from connexion.apps.flask_app import FlaskJSONEncoder
from flask.json.provider import DefaultJSONProvider

from pcluster.api.models.base_model_ import Model
from pcluster.utils import to_iso_timestr


class JSONEncoder(FlaskJSONEncoder):
class JSONEncoder(json.JSONEncoder):
"""Make the model objects JSON serializable."""

include_nulls = False
Expand All @@ -35,4 +36,25 @@ def default(self, obj): # pylint: disable=arguments-renamed
return dikt
elif isinstance(obj, datetime.date):
return to_iso_timestr(obj)
return FlaskJSONEncoder.default(self, obj)
return json.JSONEncoder.default(self, obj)


class FlaskJSONEncoder(DefaultJSONProvider):
    """JSON provider that makes the model objects JSON serializable."""

    # When False, model attributes whose value is None are omitted from the serialized dict.
    include_nulls = False

    def default(self, obj):  # pylint: disable=arguments-renamed
        """
        Override the base method to add support for model objects serialization.

        :param obj: object the JSON serializer could not handle natively
        :return: a dict for ``Model`` instances, the result of ``to_iso_timestr`` for
            ``datetime.date`` instances (which includes ``datetime.datetime``), otherwise
            whatever the parent provider's ``default`` returns
        """
        if isinstance(obj, Model):
            dikt = {}
            for attr in obj.openapi_types:
                value = getattr(obj, attr)
                if value is None and not self.include_nulls:
                    continue
                # The output key is the name attribute_map associates with the Python attribute.
                dikt[obj.attribute_map[attr]] = value
            return dikt
        if isinstance(obj, datetime.date):
            return to_iso_timestr(obj)
        return super().default(obj)
2 changes: 1 addition & 1 deletion cli/src/pcluster/api/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.

from connexion import ProblemException
from connexion.exceptions import ProblemException
from werkzeug.exceptions import HTTPException

from pcluster.api.models import (
Expand Down
9 changes: 5 additions & 4 deletions cli/src/pcluster/api/flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import functools
import logging

import connexion
from connexion import ProblemException
from connexion.apps.flask_app import FlaskApp
from connexion.decorators.validation import ParameterValidator
from connexion.exceptions import ProblemException
from flask import Response, jsonify, request
from werkzeug.exceptions import HTTPException

Expand Down Expand Up @@ -74,9 +74,10 @@ def __init__(self, swagger_ui: bool = False, validate_responses=False):
assert_valid_node_js()
options = {"swagger_ui": swagger_ui}

self.app = connexion.FlaskApp(__name__, specification_dir="openapi/", skip_error_handlers=True)
self.app = FlaskApp(__name__, specification_dir="openapi/", skip_error_handlers=True)
self.flask_app = self.app.app
self.flask_app.json_encoder = encoder.JSONEncoder
self.flask_app.json_provider_class = encoder.FlaskJSONEncoder
self.flask_app.json = encoder.FlaskJSONEncoder(self.flask_app)
self.app.add_api(
"openapi.yaml",
arguments={"title": "ParallelCluster"},
Expand Down
Loading